class Dataset:
    """Class that opens a single NDTiffStorage dataset.

    A dataset is either opened from a directory on disk (``dataset_path``) or
    is a live view onto storage owned by the Java side (``remote_storage``).
    Several methods are only implemented for one of the two modes and raise
    ``Exception`` in the other.
    """

    _POSITION_AXIS = "position"
    _Z_AXIS = "z"
    _TIME_AXIS = "time"
    _CHANNEL_AXIS = "channel"

    def __init__(self, dataset_path=None, full_res_only=True, remote_storage=None):
        """
        Parameters
        ----------
        dataset_path : str
            Directory containing a "Full resolution" sub-directory (and
            optionally downsampled sub-directories). Used only when
            ``remote_storage`` is None.
        full_res_only : bool
            If True, only the "Full resolution" directory is opened
            (Default value = True)
        remote_storage :
            Storage object of an in-progress acquisition. If supplied, nothing
            is read from disk (Default value = None)

        Raises
        ------
        Exception
            If ``dataset_path`` has no "Full resolution" sub-directory.
        """
        self._tile_width = None
        self._tile_height = None
        if remote_storage is not None:
            # this dataset is a view of an active acquisition. The storage exists on the java side
            self._remote_storage = remote_storage
            self._bridge = Bridge()
            smd = self._remote_storage.get_summary_metadata()
            if "GridPixelOverlapX" in smd.keys():
                # visible tile size is the full image size minus the overlap
                self._tile_width = smd["Width"] - smd["GridPixelOverlapX"]
                self._tile_height = smd["Height"] - smd["GridPixelOverlapY"]
            return
        else:
            self._remote_storage = None

        self.path = dataset_path
        res_dirs = [
            dI for dI in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, dI))
        ]
        # map from resolution level index (0 is full resolution) to resolution level object
        self.res_levels = {}
        if "Full resolution" not in res_dirs:
            raise Exception(
                "Couldn't find full resolution directory. Is this the correct path to a dataset?"
            )
        num_tiffs = 0
        count = 0
        for res_dir in res_dirs:
            for file in os.listdir(os.path.join(dataset_path, res_dir)):
                if file.endswith(".tif"):
                    num_tiffs += 1
        for res_dir in res_dirs:
            if full_res_only and res_dir != "Full resolution":
                continue
            res_dir_path = os.path.join(dataset_path, res_dir)
            res_level = _ResolutionLevel(res_dir_path, count, num_tiffs)
            if res_dir == "Full resolution":
                # TODO: might want to move this within the resolution level class to facilitate loading pyramids
                self.res_levels[0] = res_level
                # get summary metadata and index tree from full resolution image
                self.summary_metadata = res_level.reader_list[0].summary_md
                self.rgb = res_level.reader_list[0].rgb
                self._channel_names = {}  # read them from image metadata
                self._extra_axes_to_storage_channel = {}

                # store some fields explicitly for easy access
                self.dtype = (
                    np.uint16 if self.summary_metadata["PixelType"] == "GRAY16" else np.uint8
                )
                self.pixel_size_xy_um = self.summary_metadata["PixelSize_um"]
                self.pixel_size_z_um = (
                    self.summary_metadata["z-step_um"]
                    if "z-step_um" in self.summary_metadata
                    else None
                )
                self.image_width = res_level.reader_list[0].width
                self.image_height = res_level.reader_list[0].height
                self.overlap = (
                    np.array(
                        [
                            self.summary_metadata["GridPixelOverlapY"],
                            self.summary_metadata["GridPixelOverlapX"],
                        ]
                    )
                    if "GridPixelOverlapY" in self.summary_metadata
                    else None
                )
                c_z_t_p_tree = res_level.reader_tree
                # the c here refers to super channels, encompassing all non-tzp axes in addition to channels
                # map of axis names to values where data exists
                self.axes = {
                    self._Z_AXIS: set(),
                    self._TIME_AXIS: set(),
                    self._POSITION_AXIS: set(),
                    self._CHANNEL_AXIS: set(),
                }
                for c in c_z_t_p_tree.keys():
                    for z in c_z_t_p_tree[c]:
                        self.axes[self._Z_AXIS].add(z)
                        for t in c_z_t_p_tree[c][z]:
                            self.axes[self._TIME_AXIS].add(t)
                            for p in c_z_t_p_tree[c][z][t]:
                                self.axes[self._POSITION_AXIS].add(p)
                                if c not in self.axes["channel"]:
                                    metadata = self.res_levels[0].read_metadata(
                                        channel_index=c, z_index=z, t_index=t, pos_index=p
                                    )
                                    current_axes = metadata["Axes"]
                                    non_zpt_axes = {}
                                    for axis in current_axes:
                                        if axis not in [
                                            self._Z_AXIS,
                                            self._TIME_AXIS,
                                            self._POSITION_AXIS,
                                        ]:
                                            if axis not in self.axes:
                                                self.axes[axis] = set()
                                            self.axes[axis].add(current_axes[axis])
                                            non_zpt_axes[axis] = current_axes[axis]

                                    self._channel_names[metadata["Channel"]] = non_zpt_axes[
                                        self._CHANNEL_AXIS
                                    ]
                                    self._extra_axes_to_storage_channel[
                                        frozenset(non_zpt_axes.items())
                                    ] = c

                # remove axes with no variation
                single_axes = [axis for axis in self.axes if len(self.axes[axis]) == 1]
                for axis in single_axes:
                    del self.axes[axis]

                if "position" in self.axes and "GridPixelOverlapX" in self.summary_metadata:
                    # Make an n x 2 array with nan's where no positions actually exist
                    self.row_col_array = np.ones((len(self.axes["position"]), 2)) * np.nan
                    self.position_centers = np.ones((len(self.axes["position"]), 2)) * np.nan
                    # in case position index doesn't start at 0, map each position index to
                    # its slot in the iteration order of self.axes['position'] (built once
                    # instead of rebuilding a list for every image)
                    position_slots = {
                        position: slot for slot, position in enumerate(self.axes["position"])
                    }
                    for c_index in c_z_t_p_tree.keys():
                        for z_index in c_z_t_p_tree[c_index].keys():
                            for t_index in c_z_t_p_tree[c_index][z_index].keys():
                                p_indices = c_z_t_p_tree[c_index][z_index][t_index].keys()
                                for p_index in p_indices:
                                    pos_index_index = position_slots[p_index]
                                    if not np.isnan(self.row_col_array[pos_index_index, 0]):
                                        # already figured this one out
                                        continue
                                    if not res_level.check_ifd(
                                        channel_index=c_index,
                                        z_index=z_index,
                                        t_index=t_index,
                                        pos_index=p_index,
                                    ):
                                        # this position is corrupted; its row stays nan
                                        warnings.warn(
                                            "Corrupted image p: {} c: {} t: {} z: {}".format(
                                                p_index, c_index, t_index, z_index
                                            )
                                        )
                                    else:
                                        md = res_level.read_metadata(
                                            channel_index=c_index,
                                            pos_index=p_index,
                                            t_index=t_index,
                                            z_index=z_index,
                                        )
                                        self.row_col_array[pos_index_index] = np.array(
                                            [md["GridRowIndex"], md["GridColumnIndex"]]
                                        )
                                        self.position_centers[pos_index_index] = np.array(
                                            [
                                                md["XPosition_um_Intended"],
                                                md["YPosition_um_Intended"],
                                            ]
                                        )

            else:
                # directory names of downsampled levels contain the factor after an "x";
                # the level index is log2 of that factor
                self.res_levels[int(np.log2(int(res_dir.split("x")[1])))] = res_level
        print("\rDataset opened")

    def as_array(self, stitched=False, verbose=False):
        """
        Read all data image data as one big Dask array with last two axes as y, x and preceeding axes depending on data.
        The dask array is made up of memory-mapped numpy arrays, so the dataset does not need to be able to fit into RAM.
        If the data doesn't fully fill out the array (e.g. not every z-slice collected at every time point), zeros will
        be added automatically.

        To convert data into a numpy array, call np.asarray() on the returned result. However, doing so will bring the
        data into RAM, so it may be better to do this on only a slice of the array at a time.

        Parameters
        ----------
        stitched : bool
            If true and tiles were acquired in a grid, lay out adjacent tiles next to one another (Default value = False)
        verbose : bool
            If True print updates on progress loading the image

        Returns
        -------
        dataset : dask array

        Raises
        ------
        Exception
            If this dataset is a view of an in-progress acquisition.
        """
        if self._remote_storage is not None:
            raise Exception("Method not yet implemented for in progress acquisitions")

        # shared placeholder returned wherever no image exists
        self._empty_tile = (
            np.zeros((self.image_height, self.image_width), self.dtype)
            if not self.rgb
            else np.zeros((self.image_height, self.image_width, 3), self.dtype)
        )
        self._count = 1
        total = np.prod([len(v) for v in self.axes.values()])

        def recurse_axes(loop_axes, point_axes):
            # loop_axes: axes still to iterate over; point_axes: values fixed so far
            if len(loop_axes.values()) == 0:
                if verbose:
                    print("\rAdding data chunk {} of {}".format(self._count, total), end="")
                self._count += 1
                if None not in point_axes.values() and self.has_image(**point_axes):
                    return self.read_image(**point_axes, memmapped=True)
                else:
                    return self._empty_tile
            else:
                # do position first because it makes stitching faster
                axis = (
                    "position"
                    if "position" in loop_axes.keys() and stitched
                    else list(loop_axes.keys())[0]
                )
                remaining_axes = loop_axes.copy()
                del remaining_axes[axis]
                if axis == "position" and stitched:
                    # Stitch tiles acquired in a grid
                    self.half_overlap = self.overlap[0] // 2

                    # get spatial layout of position indices
                    zero_min_row_col = self.row_col_array - np.nanmin(self.row_col_array, axis=0)
                    row_col_mat = np.nan * np.ones(
                        [
                            int(np.nanmax(zero_min_row_col[:, 0])) + 1,
                            int(np.nanmax(zero_min_row_col[:, 1])) + 1,
                        ]
                    )
                    positions_indices = np.array(list(loop_axes["position"]))
                    rows = zero_min_row_col[positions_indices][:, 0]
                    cols = zero_min_row_col[positions_indices][:, 1]
                    # mask in case some positions were corrupted
                    mask = np.logical_not(np.isnan(rows))
                    # builtin int: the np.int alias was removed in NumPy 1.24
                    row_col_mat[
                        rows[mask].astype(int), cols[mask].astype(int)
                    ] = positions_indices[mask]

                    blocks = []
                    for row in row_col_mat:
                        blocks.append([])
                        for p_index in row:
                            if verbose:
                                print(
                                    "\rAdding data chunk {} of {}".format(self._count, total),
                                    end="",
                                )
                            valed_axes = point_axes.copy()
                            # nan marks a grid cell with no position; None makes the leaf
                            # return the empty tile
                            valed_axes[axis] = int(p_index) if not np.isnan(p_index) else None
                            blocks[-1].append(da.stack(recurse_axes(remaining_axes, valed_axes)))

                    if self.rgb:
                        # concatenate along the row/column axes that precede the trailing
                        # colour axis
                        stitched_array = np.concatenate(
                            [
                                np.concatenate(row, axis=len(blocks[0][0].shape) - 2)
                                for row in blocks
                            ],
                            axis=len(blocks[0][0].shape) - 3,
                        )
                    else:
                        stitched_array = da.block(blocks)
                    return stitched_array
                else:
                    blocks = []
                    for val in loop_axes[axis]:
                        valed_axes = point_axes.copy()
                        valed_axes[axis] = val
                        blocks.append(recurse_axes(remaining_axes, valed_axes))
                    return blocks

        blocks = recurse_axes(self.axes, {})

        if verbose:
            print(
                " Stacking tiles"
            )  # extra space otherwise there is no space after the "Adding data chunk {} {}"
        array = da.stack(blocks)
        if verbose:
            print("\rDask array opened")
        return array

    def _convert_to_storage_axes(self, axes, channel_name=None):
        """Convert an abitrary set of axes to cztp axes as in the underlying storage

        Note that ``axes`` is modified in place: a "channel" entry is always
        set, either from ``channel_name`` or defaulting to 0.

        Parameters
        ----------
        axes : dict
            axis name -> integer position
        channel_name : str
            Overrides any "channel" entry in ``axes`` (Default value = None)

        Returns
        -------
        tuple
            (storage channel index, time index, position index, z index)

        Raises
        ------
        Exception
            If ``channel_name`` is not a known channel or an axis is unknown.
        KeyError
            If the combination of non-z/time/position axes was never recorded.
        """
        if channel_name is not None:
            if channel_name not in self._channel_names.keys():
                raise Exception("Channel name {} not found".format(channel_name))
            axes[self._CHANNEL_AXIS] = self._channel_names[channel_name]
        if self._CHANNEL_AXIS not in axes:
            axes[self._CHANNEL_AXIS] = 0

        z_index = axes[self._Z_AXIS] if self._Z_AXIS in axes else 0
        t_index = axes[self._TIME_AXIS] if self._TIME_AXIS in axes else 0
        p_index = axes[self._POSITION_AXIS] if self._POSITION_AXIS in axes else 0

        non_zpt_axes = {
            key: axes[key]
            for key in axes.keys()
            if key not in [self._TIME_AXIS, self._POSITION_AXIS, self._Z_AXIS]
        }
        for axis in non_zpt_axes.keys():
            if axis not in self.axes.keys() and axis != "channel":
                raise Exception("Unknown axis: {}".format(axis))
        c_index = self._extra_axes_to_storage_channel[frozenset(non_zpt_axes.items())]
        return c_index, t_index, p_index, z_index

    def has_image(
        self,
        channel=None,
        z=None,
        time=None,
        position=None,
        channel_name=None,
        resolution_level=0,
        row=None,
        col=None,
        **kwargs
    ):
        """Check if this image is present in the dataset

        Parameters
        ----------
        channel : int
            index of the channel, if applicable (Default value = None)
        z : int
            index of z slice, if applicable (Default value = None)
        time : int
            index of the time point, if applicable (Default value = None)
        position : int
            index of the XY position, if applicable (Default value = None)
        channel_name : str
            Name of the channel. Overrides channel index if supplied (Default value = None)
        row : int
            index of tile row for XY tiled datasets (Default value = None)
        col : int
            index of tile col for XY tiled datasets (Default value = None)
        resolution_level :
            0 is full resolution, otherwise represents downampling of pixels
            at 2 ** (resolution_level) (Default value = 0)
        **kwargs :
            names and integer positions of any other axes

        Returns
        -------
        bool :
            indicating whether the dataset has an image matching the specifications

        Raises
        ------
        Exception
            If row/col are given for a dataset opened from disk.
        """
        if channel is not None:
            kwargs["channel"] = channel
        if z is not None:
            kwargs["z"] = z
        if time is not None:
            kwargs["time"] = time
        if position is not None:
            kwargs["position"] = position

        if self._remote_storage is not None:
            axes = self._bridge.construct_java_object("java.util.HashMap")
            for key in kwargs.keys():
                axes.put(key, kwargs[key])
            if row is not None and col is not None:
                return self._remote_storage.has_tile_by_row_col(axes, resolution_level, row, col)
            else:
                return self._remote_storage.has_image(axes, resolution_level)

        if row is not None or col is not None:
            raise Exception("row col lookup not yet implmented for saved datasets")
            # self.row_col_array #TODO: find position index in here

        storage_c_index, t_index, p_index, z_index = self._convert_to_storage_axes(
            kwargs, channel_name=channel_name
        )
        c_z_t_p_tree = self.res_levels[resolution_level].reader_tree
        if (
            storage_c_index in c_z_t_p_tree
            and z_index in c_z_t_p_tree[storage_c_index]
            and t_index in c_z_t_p_tree[storage_c_index][z_index]
            and p_index in c_z_t_p_tree[storage_c_index][z_index][t_index]
        ):
            res_level = self.res_levels[resolution_level]
            return res_level.check_ifd(
                channel_index=storage_c_index, z_index=z_index, t_index=t_index, pos_index=p_index
            )
        return False

    def read_image(
        self,
        channel=None,
        z=None,
        time=None,
        position=None,
        channel_name=None,
        read_metadata=False,
        resolution_level=0,
        row=None,
        col=None,
        memmapped=False,
        **kwargs
    ):
        """
        Read image data as numpy array

        Parameters
        ----------
        channel : int
            index of the channel, if applicable (Default value = None)
        z : int
            index of z slice, if applicable (Default value = None)
        time : int
            index of the time point, if applicable (Default value = None)
        position : int
            index of the XY position, if applicable (Default value = None)
        channel_name :
            Name of the channel. Overrides channel index if supplied (Default value = None)
        row : int
            index of tile row for XY tiled datasets (Default value = None)
        col : int
            index of tile col for XY tiled datasets (Default value = None)
        resolution_level :
            0 is full resolution, otherwise represents downampling of pixels
            at 2 ** (resolution_level) (Default value = 0)
        read_metadata : bool
            If True, also return the image metadata (Default value = False)
        memmapped : bool
            Not available for in progress acquisitions (Default value = False)
        **kwargs :
            names and integer positions of any other axes

        Returns
        -------
        image : numpy array or tuple
            image as a 2D numpy array, or tuple with image and image metadata as dict.
            For in progress acquisitions, None if the image does not exist.

        Raises
        ------
        Exception
            If ``memmapped`` is requested for an in progress acquisition, or
            row/col are given for a dataset opened from disk.
        """
        if channel is not None:
            kwargs["channel"] = channel
        if z is not None:
            kwargs["z"] = z
        if time is not None:
            kwargs["time"] = time
        if position is not None:
            kwargs["position"] = position

        if self._remote_storage is not None:
            if memmapped:
                raise Exception("Memory mapping not available for in progress acquisitions")
            axes = self._bridge.construct_java_object("java.util.HashMap")
            for key in kwargs.keys():
                axes.put(key, kwargs[key])
            if not self._remote_storage.has_image(axes, resolution_level):
                return None
            if row is not None and col is not None:
                tagged_image = self._remote_storage.get_tile_by_row_col(
                    axes, resolution_level, row, col
                )
            else:
                tagged_image = self._remote_storage.get_image(axes, resolution_level)
            if tagged_image is None:
                return None
            if resolution_level == 0:
                # shape passed positionally: the ``newshape`` keyword is deprecated in
                # recent NumPy releases
                image = np.reshape(
                    tagged_image.pix,
                    [tagged_image.tags["Height"], tagged_image.tags["Width"]],
                )
                if (self._tile_height is not None) and (self._tile_width is not None):
                    # crop down to just the part that shows (i.e. no overlap).
                    # Explicit start/stop rather than ``a:-a``, which yields an empty
                    # array when the overlap is 0
                    top = (image.shape[0] - self._tile_height) // 2
                    left = (image.shape[1] - self._tile_width) // 2
                    image = image[
                        top : top + self._tile_height,
                        left : left + self._tile_width,
                    ]
            else:
                image = np.reshape(tagged_image.pix, [self._tile_height, self._tile_width])
            if read_metadata:
                return image, tagged_image.tags
            return image

        if row is not None or col is not None:
            raise Exception("row col lookup not yet implmented for saved datasets")
            # self.row_col_array #TODO: find position index in here

        storage_c_index, t_index, p_index, z_index = self._convert_to_storage_axes(
            kwargs, channel_name=channel_name
        )
        res_level = self.res_levels[resolution_level]
        return res_level.read_image(
            storage_c_index, z_index, t_index, p_index, read_metadata, memmapped
        )

    def read_first_image_metadata(self):
        """
        Get the metadata of the first image in the full resolution index tree
        (first key at each of the channel, z, time and position levels). This
        is useful for accessing image metadata in a dataset with sparse axes.

        Returns
        -------
        metadata : dict
        """
        cztp_tree = self.res_levels[0].reader_tree
        c = list(cztp_tree.keys())[0]
        z = list(cztp_tree[c].keys())[0]
        t = list(cztp_tree[c][z].keys())[0]
        p = list(cztp_tree[c][z][t].keys())[0]
        return self.res_levels[0].read_metadata(c, z, t, p)

    def read_metadata(
        self,
        channel=None,
        z=None,
        time=None,
        position=None,
        channel_name=None,
        row=None,
        col=None,
        resolution_level=0,
        **kwargs
    ):
        """
        Read metadata only. Faster than using read_image to retrieve metadata

        Parameters
        ----------
        channel : int
            index of the channel, if applicable (Default value = None)
        z : int
            index of z slice, if applicable (Default value = None)
        time : int
            index of the time point, if applicable (Default value = None)
        position : int
            index of the XY position, if applicable (Default value = None)
        channel_name :
            Name of the channel. Overrides channel index if supplied (Default value = None)
        row : int
            index of tile row for XY tiled datasets (Default value = None)
        col : int
            index of tile col for XY tiled datasets (Default value = None)
        resolution_level :
            0 is full resolution, otherwise represents downampling of pixels
            at 2 ** (resolution_level) (Default value = 0)
        **kwargs :
            names and integer positions of any other axes

        Returns
        -------
        metadata : dict
            For in progress acquisitions, None if the image does not exist.
        """
        if channel is not None:
            kwargs["channel"] = channel
        if z is not None:
            kwargs["z"] = z
        if time is not None:
            kwargs["time"] = time
        if position is not None:
            kwargs["position"] = position

        if self._remote_storage is not None:
            # read the tagged image because there is no function in the Java API for
            # metadata only. The named axes already live in kwargs, so they must not
            # be passed a second time by keyword (that raises TypeError).
            result = self.read_image(
                channel_name=channel_name,
                read_metadata=True,
                resolution_level=resolution_level,
                row=row,
                col=col,
                **kwargs
            )
            return None if result is None else result[1]

        storage_c_index, t_index, p_index, z_index = self._convert_to_storage_axes(
            kwargs, channel_name=channel_name
        )
        res_level = self.res_levels[resolution_level]
        return res_level.read_metadata(storage_c_index, z_index, t_index, p_index)

    def close(self):
        """Close every opened resolution level (no-op for in progress acquisitions)."""
        if self._remote_storage is not None:
            # nothing to do, this is handled on the java side
            return
        # res_levels maps level index -> level object; close the objects, not the keys
        for res_level in self.res_levels.values():
            res_level.close()

    def get_channel_names(self):
        """Return the channel names read from image metadata (a dict keys view).

        Raises
        ------
        Exception
            If this dataset is a view of an in-progress acquisition.
        """
        if self._remote_storage is not None:
            raise Exception("Not implemented for in progress datasets")
        return self._channel_names.keys()
class Dataset:
    """
    Class that opens a single NDTiffStorage dataset

    A dataset is either opened from a directory on disk (``dataset_path``) or
    is a live view onto storage owned by the Java side (``remote_storage``).
    Several methods are only implemented for one of the two modes and raise
    ``Exception`` in the other.
    """

    _POSITION_AXIS = 'position'
    _Z_AXIS = 'z'
    _TIME_AXIS = 'time'
    _CHANNEL_AXIS = 'channel'

    def __init__(self, dataset_path=None, full_res_only=True, remote_storage=None):
        """
        :param dataset_path: directory containing a 'Full resolution' sub-directory. Used only when
            remote_storage is None
        :type dataset_path: str
        :param full_res_only: if True, only the 'Full resolution' directory is opened
        :type full_res_only: boolean
        :param remote_storage: storage object of an in progress acquisition. If supplied, nothing is
            read from disk
        :raises Exception: if dataset_path has no 'Full resolution' sub-directory
        """
        # stay None when the summary metadata describes no tile overlap
        self._tile_width = None
        self._tile_height = None
        if remote_storage is not None:
            # this dataset is a view of an active acquisition. The storage exists on the java side
            self._remote_storage = remote_storage
            self._bridge = Bridge()
            smd = self._remote_storage.get_summary_metadata()
            if 'GridPixelOverlapX' in smd.keys():
                # visible tile size is the full image size minus the overlap
                self._tile_width = smd['Width'] - smd['GridPixelOverlapX']
                self._tile_height = smd['Height'] - smd['GridPixelOverlapY']
            return
        else:
            self._remote_storage = None

        self.path = dataset_path
        res_dirs = [
            dI for dI in os.listdir(dataset_path)
            if os.path.isdir(os.path.join(dataset_path, dI))
        ]
        # map from resolution level index (0 is full resolution) to resolution level object
        self.res_levels = {}
        if 'Full resolution' not in res_dirs:
            raise Exception(
                "Couldn't find full resolution directory. Is this the correct path to a dataset?"
            )
        num_tiffs = 0
        count = 0
        for res_dir in res_dirs:
            for file in os.listdir(os.path.join(dataset_path, res_dir)):
                if file.endswith('.tif'):
                    num_tiffs += 1
        for res_dir in res_dirs:
            if full_res_only and res_dir != 'Full resolution':
                continue
            res_dir_path = os.path.join(dataset_path, res_dir)
            res_level = _ResolutionLevel(res_dir_path, count, num_tiffs)
            if res_dir == 'Full resolution':
                # TODO: might want to move this within the resolution level class to facilitate loading pyramids
                self.res_levels[0] = res_level
                # get summary metadata and index tree from full resolution image
                self.summary_metadata = res_level.reader_list[0].summary_md
                self._channel_names = {}  # read them from image metadata
                self._extra_axes_to_storage_channel = {}

                # store some fields explicitly for easy access
                self.dtype = (np.uint16
                              if self.summary_metadata['PixelType'] == 'GRAY16'
                              else np.uint8)
                self.pixel_size_xy_um = self.summary_metadata['PixelSize_um']
                self.pixel_size_z_um = (self.summary_metadata['z-step_um']
                                        if 'z-step_um' in self.summary_metadata
                                        else None)
                self.image_width = res_level.reader_list[0].width
                self.image_height = res_level.reader_list[0].height
                self.overlap = np.array([
                    self.summary_metadata['GridPixelOverlapY'],
                    self.summary_metadata['GridPixelOverlapX']
                ]) if 'GridPixelOverlapY' in self.summary_metadata else None
                c_z_t_p_tree = res_level.reader_tree
                # the c here refers to super channels, encompassing all non-tzp axes in addition to channels
                # map of axis names to values where data exists
                self.axes = {
                    self._Z_AXIS: set(),
                    self._TIME_AXIS: set(),
                    self._POSITION_AXIS: set(),
                    self._CHANNEL_AXIS: set()
                }
                for c in c_z_t_p_tree.keys():
                    for z in c_z_t_p_tree[c]:
                        self.axes[self._Z_AXIS].add(z)
                        for t in c_z_t_p_tree[c][z]:
                            self.axes[self._TIME_AXIS].add(t)
                            for p in c_z_t_p_tree[c][z][t]:
                                self.axes[self._POSITION_AXIS].add(p)
                                if c not in self.axes['channel']:
                                    metadata = self.res_levels[0].read_metadata(
                                        channel_index=c, z_index=z, t_index=t, pos_index=p)
                                    current_axes = metadata['Axes']
                                    non_zpt_axes = {}
                                    for axis in current_axes:
                                        if axis not in [
                                                self._Z_AXIS, self._TIME_AXIS,
                                                self._POSITION_AXIS
                                        ]:
                                            if axis not in self.axes:
                                                self.axes[axis] = set()
                                            self.axes[axis].add(current_axes[axis])
                                            non_zpt_axes[axis] = current_axes[axis]

                                    self._channel_names[metadata['Channel']] = non_zpt_axes[
                                        self._CHANNEL_AXIS]
                                    self._extra_axes_to_storage_channel[
                                        frozenset(non_zpt_axes.items())] = c

                # remove axes with no variation
                single_axes = [
                    axis for axis in self.axes if len(self.axes[axis]) == 1
                ]
                for axis in single_axes:
                    del self.axes[axis]

                if 'position' in self.axes and 'GridPixelOverlapX' in self.summary_metadata:
                    # Make an n x 2 array with nan's where no positions actually exist.
                    # Exactly one row is appended per position index so that row i of
                    # row_col_array describes position i.
                    row_cols = []
                    positions_checked = set()
                    for c_index in c_z_t_p_tree.keys():
                        for z_index in c_z_t_p_tree[c_index].keys():
                            for t_index in c_z_t_p_tree[c_index][z_index].keys():
                                p_indices = c_z_t_p_tree[c_index][z_index][t_index].keys()
                                for p_index in range(max(p_indices) + 1):
                                    if p_index in positions_checked:
                                        continue
                                    if p_index not in p_indices:
                                        row_cols.append(np.array([np.nan, np.nan]))
                                    elif not res_level.check_ifd(
                                            channel_index=c_index,
                                            z_index=z_index,
                                            t_index=t_index,
                                            pos_index=p_index):
                                        # this position is corrupted
                                        warnings.warn(
                                            'Corrupted image p: {} c: {} t: {} z: {}'
                                            .format(p_index, c_index, t_index, z_index))
                                        row_cols.append(np.array([np.nan, np.nan]))
                                    else:
                                        md = res_level.read_metadata(
                                            channel_index=c_index,
                                            pos_index=p_index,
                                            t_index=t_index,
                                            z_index=z_index)
                                        row_cols.append(
                                            np.array([
                                                md['GridRowIndex'],
                                                md['GridColumnIndex']
                                            ]))
                                    positions_checked.add(p_index)
                    self.row_col_array = np.stack(row_cols)

            else:
                # directory names of downsampled levels contain the factor after an 'x';
                # the level index is log2 of that factor
                self.res_levels[int(np.log2(int(res_dir.split('x')[1])))] = res_level
        print('\rDataset opened')

    def as_array(self, stitched=False):
        """
        Read all data image data as one big Dask array with last two axes as y, x and preceeding axes depending on data.
        The dask array is made up of memory-mapped numpy arrays, so the dataset does not need to be able to fit into RAM.
        If the data doesn't fully fill out the array (e.g. not every z-slice collected at every time point), zeros will
        be added automatically.

        To convert data into a numpy array, call np.asarray() on the returned result. However, doing so will bring the
        data into RAM, so it may be better to do this on only a slice of the array at a time.

        :param stitched: If true and tiles were acquired in a grid, lay out adjacent tiles next to one another
        :type stitched: boolean
        :return: dask array
        :raises Exception: if this dataset is a view of an in progress acquisition
        """
        if self._remote_storage is not None:
            raise Exception(
                'Method not yet implemented for in progress acquisitions')

        # shared placeholder returned wherever no image exists
        self._empty_tile = np.zeros((self.image_height, self.image_width), self.dtype)
        self._count = 1
        total = np.prod([len(v) for v in self.axes.values()])

        def recurse_axes(loop_axes, point_axes):
            # loop_axes: axes still to iterate over; point_axes: values fixed so far
            if len(loop_axes.values()) == 0:
                print('\rAdding data chunk {} of {}'.format(self._count, total), end='')
                self._count += 1
                if None not in point_axes.values() and self.has_image(**point_axes):
                    return self.read_image(**point_axes, memmapped=True)
                else:
                    return self._empty_tile
            else:
                # do position first because it makes stitching faster
                axis = ('position'
                        if 'position' in loop_axes.keys() and stitched
                        else list(loop_axes.keys())[0])
                remaining_axes = loop_axes.copy()
                del remaining_axes[axis]
                if axis == 'position' and stitched:
                    # Stitch tiles acquired in a grid
                    self.half_overlap = self.overlap[0] // 2

                    # get spatial layout of position indices
                    zero_min_row_col = (self.row_col_array -
                                        np.nanmin(self.row_col_array, axis=0))
                    row_col_mat = np.nan * np.ones([
                        int(np.nanmax(zero_min_row_col[:, 0])) + 1,
                        int(np.nanmax(zero_min_row_col[:, 1])) + 1
                    ])
                    positions_indices = np.array(list(loop_axes['position']))
                    rows = zero_min_row_col[positions_indices][:, 0]
                    cols = zero_min_row_col[positions_indices][:, 1]
                    # mask in case some positions were corrupted
                    mask = np.logical_not(np.isnan(rows))
                    # builtin int: the np.int alias was removed in NumPy 1.24
                    row_col_mat[rows[mask].astype(int),
                                cols[mask].astype(int)] = positions_indices[mask]

                    blocks = []
                    for row in row_col_mat:
                        blocks.append([])
                        for p_index in row:
                            print('\rAdding data chunk {} of {}'.format(
                                self._count, total), end='')
                            valed_axes = point_axes.copy()
                            # nan marks a grid cell with no position; None makes the leaf
                            # return the empty tile
                            valed_axes[axis] = int(p_index) if not np.isnan(p_index) else None
                            blocks[-1].append(
                                da.stack(recurse_axes(remaining_axes, valed_axes)))

                    stitched_array = da.block(blocks)
                    return stitched_array
                else:
                    blocks = []
                    for val in loop_axes[axis]:
                        valed_axes = point_axes.copy()
                        valed_axes[axis] = val
                        blocks.append(recurse_axes(remaining_axes, valed_axes))
                    return blocks

        blocks = recurse_axes(self.axes, {})

        print('Stacking tiles')
        array = da.stack(blocks)
        print('\rDask array opened')
        return array

    def _convert_to_storage_axes(self, axes, channel_name=None):
        """
        Convert an abitrary set of axes to cztp axes as in the underlying storage

        Note that axes is modified in place: a 'channel' entry is always set, either from
        channel_name or defaulting to 0.

        :param axes: dict of axis name -> integer position
        :param channel_name: overrides any 'channel' entry in axes
        :return: tuple of (storage channel index, time index, position index, z index)
        :raises Exception: if channel_name is not a known channel or an axis is unknown
        :raises KeyError: if the combination of non-z/time/position axes was never recorded
        """
        if channel_name is not None:
            if channel_name not in self._channel_names.keys():
                raise Exception(
                    'Channel name {} not found'.format(channel_name))
            axes[self._CHANNEL_AXIS] = self._channel_names[channel_name]
        if self._CHANNEL_AXIS not in axes:
            axes[self._CHANNEL_AXIS] = 0

        z_index = axes[self._Z_AXIS] if self._Z_AXIS in axes else 0
        t_index = axes[self._TIME_AXIS] if self._TIME_AXIS in axes else 0
        p_index = axes[self._POSITION_AXIS] if self._POSITION_AXIS in axes else 0

        non_zpt_axes = {
            key: axes[key]
            for key in axes.keys() if key not in
            [self._TIME_AXIS, self._POSITION_AXIS, self._Z_AXIS]
        }
        for axis in non_zpt_axes.keys():
            if axis not in self.axes.keys() and axis != 'channel':
                raise Exception('Unknown axis: {}'.format(axis))
        c_index = self._extra_axes_to_storage_channel[frozenset(non_zpt_axes.items())]
        return c_index, t_index, p_index, z_index

    def has_image(self,
                  channel=None,
                  z=None,
                  time=None,
                  position=None,
                  channel_name=None,
                  resolution_level=0,
                  row=None,
                  col=None,
                  **kwargs):
        """
        Check if this image is present in the dataset

        :param channel: index of the channel, if applicable
        :type channel: int
        :param z: index of z slice, if applicable
        :type z: int
        :param time: index of the time point, if applicable
        :type time: int
        :param position: index of the XY position, if applicable
        :type position: int
        :param channel_name: Name of the channel. Overrides channel index if supplied
        :type channel_name: str
        :param row: index of tile row for XY tiled datasets
        :type row: int
        :param col: index of tile col for XY tiled datasets
        :type col: int
        :param resolution_level: 0 is full resolution, otherwise represents downampling of pixels
            at 2 ** (resolution_level)
        :param kwargs: names and integer positions of any other axes
        :return: boolean indicating whether image present
        :raises Exception: if row/col are given for a dataset opened from disk
        """
        if channel is not None:
            kwargs['channel'] = channel
        if z is not None:
            kwargs['z'] = z
        if time is not None:
            kwargs['time'] = time
        if position is not None:
            kwargs['position'] = position

        if self._remote_storage is not None:
            axes = self._bridge.construct_java_object('java.util.HashMap')
            for key in kwargs.keys():
                axes.put(key, kwargs[key])
            if row is not None and col is not None:
                return self._remote_storage.has_tile_by_row_col(
                    axes, resolution_level, row, col)
            else:
                return self._remote_storage.has_image(axes, resolution_level)

        if row is not None or col is not None:
            raise Exception(
                'row col lookup not yet implmented for saved datasets')
            # self.row_col_array #TODO: find position index in here

        storage_c_index, t_index, p_index, z_index = self._convert_to_storage_axes(
            kwargs, channel_name=channel_name)
        c_z_t_p_tree = self.res_levels[resolution_level].reader_tree
        if (storage_c_index in c_z_t_p_tree
                and z_index in c_z_t_p_tree[storage_c_index]
                and t_index in c_z_t_p_tree[storage_c_index][z_index]
                and p_index in c_z_t_p_tree[storage_c_index][z_index][t_index]):
            res_level = self.res_levels[resolution_level]
            return res_level.check_ifd(channel_index=storage_c_index,
                                       z_index=z_index,
                                       t_index=t_index,
                                       pos_index=p_index)
        return False

    def read_image(self,
                   channel=None,
                   z=None,
                   time=None,
                   position=None,
                   channel_name=None,
                   read_metadata=False,
                   resolution_level=0,
                   row=None,
                   col=None,
                   memmapped=False,
                   **kwargs):
        """
        Read image data as numpy array

        :param channel: index of the channel, if applicable
        :type channel: int
        :param z: index of z slice, if applicable
        :type z: int
        :param time: index of the time point, if applicable
        :type time: int
        :param position: index of the XY position, if applicable
        :type position: int
        :param channel_name: Name of the channel. Overrides channel index if supplied
        :param read_metadata: if True, also return the image metadata
        :type read_metadata: boolean
        :param row: index of tile row for XY tiled datasets
        :type row: int
        :param col: index of tile col for XY tiled datasets
        :type col: int
        :param resolution_level: 0 is full resolution, otherwise represents downampling of pixels
            at 2 ** (resolution_level)
        :param memmapped: not available for in progress acquisitions
        :type memmapped: boolean
        :param kwargs: names and integer positions of any other axes
        :return: image as 2D numpy array, or tuple with image and image metadata as dict. For in
            progress acquisitions, None if the image does not exist
        :raises Exception: if memmapped is requested for an in progress acquisition, or row/col are
            given for a dataset opened from disk
        """
        if channel is not None:
            kwargs['channel'] = channel
        if z is not None:
            kwargs['z'] = z
        if time is not None:
            kwargs['time'] = time
        if position is not None:
            kwargs['position'] = position

        if self._remote_storage is not None:
            if memmapped:
                raise Exception(
                    'Memory mapping not available for in progress acquisitions'
                )
            axes = self._bridge.construct_java_object('java.util.HashMap')
            for key in kwargs.keys():
                axes.put(key, kwargs[key])
            if not self._remote_storage.has_image(axes, resolution_level):
                return None
            if row is not None and col is not None:
                tagged_image = self._remote_storage.get_tile_by_row_col(
                    axes, resolution_level, row, col)
            else:
                tagged_image = self._remote_storage.get_image(
                    axes, resolution_level)
            if tagged_image is None:
                return None
            if resolution_level == 0:
                # shape passed positionally: the ``newshape`` keyword is deprecated in
                # recent NumPy releases
                image = np.reshape(
                    tagged_image.pix,
                    [tagged_image.tags['Height'], tagged_image.tags['Width']])
                if self._tile_height is not None and self._tile_width is not None:
                    # crop down to just the part that shows (i.e. no overlap).
                    # The column offset is derived from the image width (shape[1]), and
                    # explicit start/stop is used rather than ``a:-a``, which yields an
                    # empty array when the overlap is 0
                    top = (image.shape[0] - self._tile_height) // 2
                    left = (image.shape[1] - self._tile_width) // 2
                    image = image[top:top + self._tile_height,
                                  left:left + self._tile_width]
            else:
                image = np.reshape(tagged_image.pix,
                                   [self._tile_height, self._tile_width])
            if read_metadata:
                return image, tagged_image.tags
            return image

        if row is not None or col is not None:
            raise Exception(
                'row col lookup not yet implmented for saved datasets')
            # self.row_col_array #TODO: find position index in here

        storage_c_index, t_index, p_index, z_index = self._convert_to_storage_axes(
            kwargs, channel_name=channel_name)
        res_level = self.res_levels[resolution_level]
        return res_level.read_image(storage_c_index, z_index, t_index, p_index,
                                    read_metadata, memmapped)

    def read_metadata(self,
                      channel=None,
                      z=None,
                      time=None,
                      position=None,
                      channel_name=None,
                      row=None,
                      col=None,
                      resolution_level=0,
                      **kwargs):
        """
        Read metadata only. Faster than using read_image to retrieve metadata

        :param channel: index of the channel, if applicable
        :type channel: int
        :param z: index of z slice, if applicable
        :type z: int
        :param time: index of the time point, if applicable
        :type time: int
        :param position: index of the XY position, if applicable
        :type position: int
        :param channel_name: Name of the channel. Overrides channel index if supplied
        :param row: index of tile row for XY tiled datasets
        :type row: int
        :param col: index of tile col for XY tiled datasets
        :type col: int
        :param resolution_level: 0 is full resolution, otherwise represents downampling of pixels
            at 2 ** (resolution_level)
        :param kwargs: names and integer positions of any other axes
        :return: image metadata as dict. For in progress acquisitions, None if the image does not
            exist
        """
        if channel is not None:
            kwargs['channel'] = channel
        if z is not None:
            kwargs['z'] = z
        if time is not None:
            kwargs['time'] = time
        if position is not None:
            kwargs['position'] = position

        if self._remote_storage is not None:
            # read the tagged image because there is no function in the Java API for
            # metadata only. The named axes already live in kwargs, so they must not
            # be passed a second time by keyword (that raises TypeError).
            result = self.read_image(channel_name=channel_name,
                                     read_metadata=True,
                                     resolution_level=resolution_level,
                                     row=row,
                                     col=col,
                                     **kwargs)
            return None if result is None else result[1]

        storage_c_index, t_index, p_index, z_index = self._convert_to_storage_axes(
            kwargs, channel_name=channel_name)
        res_level = self.res_levels[resolution_level]
        return res_level.read_metadata(storage_c_index, z_index, t_index, p_index)

    def close(self):
        """
        Close every opened resolution level (no-op for in progress acquisitions)
        """
        if self._remote_storage is not None:
            # nothing to do, this is handled on the java side
            return
        # res_levels maps level index -> level object; close the objects, not the keys
        for res_level in self.res_levels.values():
            res_level.close()

    def get_channel_names(self):
        """
        Return the channel names read from image metadata (a dict keys view)

        :raises Exception: if this dataset is a view of an in progress acquisition
        """
        if self._remote_storage is not None:
            raise Exception('Not implemented for in progress datasets')
        return self._channel_names.keys()
class Acquisition(object):
    """Python-side handle on an acquisition that runs in the Java layer.

    Constructing an instance creates the remote acquisition through a
    ``Bridge``, attaches any image processor / hooks that were supplied,
    starts it, and (for non-Magellan acquisitions) starts a thread that
    forwards queued events to the remote side. The object can be used as a
    context manager: leaving the ``with`` block signals that no more events
    are coming and waits for the acquisition to finish.
    """

    def __init__(
        self,
        directory=None,
        name=None,
        image_process_fn=None,
        event_generation_hook_fn=None,
        pre_hardware_hook_fn=None,
        post_hardware_hook_fn=None,
        post_camera_hook_fn=None,
        show_display=True,
        tile_overlap=None,
        max_multi_res_index=None,
        magellan_acq_index=None,
        magellan_explore=False,
        process=False,
        debug=False,
    ):
        """
        Parameters
        ----------
        directory : str
            saving directory for this acquisition. Required unless an image process function will
            be implemented that diverts images from saving
        name : str
            Saving name for the acquisition. Required unless an image process function will
            be implemented that diverts images from saving
        image_process_fn : Callable
            image processing function that will be called on each image that gets acquired.
            Can either take two arguments (image, metadata) where image is a numpy array and
            metadata is a dict containing the corresponding image metadata. Or a 4 argument
            version is accepted, which accepts (image, metadata, bridge, queue), where bridge is
            an instance of the pycromanager.acquire.Bridge object for the purposes of interacting
            with arbitrary code on the Java side (such as the micro-manager core), and queue is a
            Queue object that holds upcoming acquisition events.
            NOTE(review): the original description of the required return value was truncated
            ("Both version must either return"); confirm against the processor implementation.
        event_generation_hook_fn : Callable
            hook function that will be run as soon as acquisition events are generated (before
            hardware sequencing optimization in the acquisition engine). This is useful if one
            wants to modify acquisition events that they didn't generate (e.g. those generated by
            a GUI application). Accepts either one argument (the current acquisition event) or
            three arguments (current event, bridge, event Queue)
        pre_hardware_hook_fn : Callable
            hook function that will be run just before the hardware is updated before acquiring
            a new image. In the case of hardware sequencing, it will be run just before a sequence
            of instructions are dispatched to the hardware. Accepts either one argument (the
            current acquisition event) or three arguments (current event, bridge, event Queue)
        post_hardware_hook_fn : Callable
            hook function that will be run just after the hardware is updated before acquiring
            a new image. In the case of hardware sequencing, it will be run just after a sequence
            of instructions are dispatched to the hardware, but before the camera sequence has
            been started. Accepts either one argument (the current acquisition event) or three
            arguments (current event, bridge, event Queue)
        post_camera_hook_fn : Callable
            hook function that will be run just after the camera has been triggered to snapImage
            or startSequence. A common use case for this hook is when one wants to send TTL
            triggers to the camera from an external timing device that synchronizes with other
            hardware. Accepts either one argument (the current acquisition event) or three
            arguments (current event, bridge, event Queue)
        tile_overlap : int or tuple of int
            If given, XY tiles will be laid out in a grid and multi-resolution saving will be
            activated. Argument can be a two element tuple describing the pixel overlaps between
            adjacent tiles. i.e. (pixel_overlap_x, pixel_overlap_y), or an integer to use the
            same overlap for both. For these features to work, the current hardware configuration
            must have a valid affine transform between camera coordinates and XY stage coordinates
        max_multi_res_index : int
            Maximum index to downsample to in multi-res pyramid mode (which is only active if a
            value for "tile_overlap" is passed in, or if running a Micro-Magellan acquisition).
            0 is no downsampling, 1 is downsampled up to 2x, 2 is downsampled up to 4x, etc.
            If not provided, it will be dynamically calculated and updated from data
        show_display : bool
            show the image viewer window
        magellan_acq_index : int
            run this acquisition using the settings specified at this position in the main GUI
            of micro-magellan (micro-manager plugin). This index starts at 0
        magellan_explore : bool
            Run a Micro-magellan explore acquisition
        process : bool
            Use multiprocessing instead of multithreading for acquisition hooks and image
            processors. This can be used to speed up CPU-bounded processing by eliminating
            bottlenecks caused by Python's Global Interpreter Lock, but also creates
            complications on Windows-based systems
        debug : bool
            whether to print debug messages
        """
        self.bridge = Bridge(debug=debug)
        self._debug = debug
        self._dataset = None

        if directory is not None:
            # Expand ~ in path
            directory = os.path.expanduser(directory)
            # If path is relative, retain knowledge of the current working directory
            directory = os.path.abspath(directory)

        if magellan_acq_index is not None:
            magellan_api = self.bridge.get_magellan()
            self._remote_acq = magellan_api.create_acquisition(magellan_acq_index)
            # Magellan acquisitions are not fed events from Python
            self._event_queue = None
        elif magellan_explore:
            magellan_api = self.bridge.get_magellan()
            self._remote_acq = magellan_api.create_explore_acquisition()
            self._event_queue = None
        else:
            # Create thread safe queue for events so they can be passed from multiple processes
            self._event_queue = multiprocessing.Queue() if process else queue.Queue()
            core = self.bridge.get_core()
            acq_factory = self.bridge.construct_java_object(
                "org.micromanager.remote.RemoteAcquisitionFactory", args=[core]
            )

            show_viewer = show_display and (directory is not None and name is not None)
            if tile_overlap is None:
                # argument placeholders, these wont actually be used
                x_overlap = 0
                y_overlap = 0
            elif isinstance(tile_overlap, tuple):
                x_overlap, y_overlap = tile_overlap
            else:
                x_overlap = tile_overlap
                y_overlap = tile_overlap

            self._remote_acq = acq_factory.create_acquisition(
                directory,
                name,
                show_viewer,
                tile_overlap is not None,
                x_overlap,
                y_overlap,
                # -1 tells the remote side that no maximum was requested
                max_multi_res_index if max_multi_res_index is not None else -1,
            )

        storage = self._remote_acq.get_data_sink()
        if storage is not None:
            self.disk_location = storage.get_disk_location()

        if image_process_fn is not None:
            processor = self.bridge.construct_java_object(
                "org.micromanager.remote.RemoteImageProcessor"
            )
            self._remote_acq.add_image_processor(processor)
            self._start_processor(processor, image_process_fn, self._event_queue, process=process)

        # Hooks are registered in the order they fire during an acquisition cycle
        if event_generation_hook_fn is not None:
            self._register_hook(event_generation_hook_fn, "EVENT_GENERATION_HOOK", process)
        if pre_hardware_hook_fn is not None:
            self._register_hook(pre_hardware_hook_fn, "BEFORE_HARDWARE_HOOK", process)
        if post_hardware_hook_fn is not None:
            self._register_hook(post_hardware_hook_fn, "AFTER_HARDWARE_HOOK", process)
        if post_camera_hook_fn is not None:
            self._register_hook(post_camera_hook_fn, "AFTER_CAMERA_HOOK", process)

        self._remote_acq.start()

        if magellan_acq_index is None and not magellan_explore:
            self.event_port = self._remote_acq.get_event_port()

            self.event_process = threading.Thread(
                target=_event_sending_fn,
                args=(self.event_port, self._event_queue, self._debug),
                name="Event sending",
            )
            self.event_process.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._event_queue is not None:  # magellan acquisitions dont have this
            # this should shut down storage and viewer as appropriate
            self._event_queue.put(None)
        # now wait on it to finish
        self.await_completion()

    def get_disk_location(self):
        """
        Return the path where the dataset is on disk
        """
        return self._remote_acq.get_storage().get_disk_location()

    def get_dataset(self):
        """
        Return a Dataset view of the remote storage of this acquisition.

        The Dataset is created on first call and cached for later calls.
        """
        if self._dataset is None:
            self._dataset = Dataset(remote_storage=self._remote_acq.get_storage())
        return self._dataset

    def await_completion(self):
        """Wait for acquisition to finish and resources to be cleaned up"""
        while not self._remote_acq.is_finished():
            time.sleep(0.1)

    def acquire(self, events, keep_shutter_open=False):
        """Submit an event or a list of events for acquisition. Optimizations (i.e. taking
        advantage of hardware synchronization, where available), will take place across this
        list of events, but not over multiple calls of this method. A single event is a python
        dictionary with a specific structure

        The caller's ``events`` object is never modified; when ``keep_shutter_open`` is set,
        copies of the events are queued instead.

        Parameters
        ----------
        events : dict or list of dict
            single event or list of events to put on the event queue
        keep_shutter_open : bool
            if True, every submitted event is marked with ``"keep_shutter_open": True`` and a
            trailing ``{"keep_shutter_open": False}`` event is added to return to autoshutter
            without acquiring an image (Default value = False)

        Raises
        ------
        RuntimeError
            if this acquisition has no event queue (Magellan acquisitions)
        """
        if self._event_queue is None:
            raise RuntimeError(
                "This acquisition has no event queue (Magellan acquisitions do not accept events)"
            )
        if keep_shutter_open and isinstance(events, list):
            # copy rather than mutate, so resubmitting the same list does not accumulate
            # shutter events or leave flags behind on the caller's dicts
            events = [dict(e, keep_shutter_open=True) for e in events]
            # return to autoshutter, dont acquire an image
            events.append({"keep_shutter_open": False})
        elif keep_shutter_open and isinstance(events, dict):
            events = [
                dict(events, keep_shutter_open=True),
                {"keep_shutter_open": False},
            ]  # return to autoshutter, dont acquire an image
        self._event_queue.put(events)

    def _register_hook(self, hook_fn, hook_type_name, process):
        """Create a remote hook, start its Python worker, and attach it to the acquisition.

        Parameters
        ----------
        hook_fn : Callable
            user function to run when the hook fires
        hook_type_name : str
            name of the hook-type constant on the remote acquisition
            (e.g. ``"BEFORE_HARDWARE_HOOK"``)
        process : bool
            run the worker in a separate process instead of a thread
        """
        hook = self.bridge.construct_java_object(
            "org.micromanager.remote.RemoteAcqHook", args=[self._remote_acq]
        )
        self._start_hook(hook, hook_fn, self._event_queue, process=process)
        self._remote_acq.add_hook(hook, getattr(self._remote_acq, hook_type_name))

    def _start_hook(self, remote_hook, remote_hook_fn, event_queue, process):
        """Start the worker that services a remote hook and block until it has connected.

        Parameters
        ----------
        remote_hook :
            remote hook object providing ``get_pull_port`` and ``get_push_port``
        remote_hook_fn : Callable
            user function handed to the worker
        event_queue :
            event queue handed to the worker
        process : bool
            use a ``multiprocessing.Process`` (and matching Event) instead of a thread
        """
        hook_connected_evt = multiprocessing.Event() if process else threading.Event()

        pull_port = remote_hook.get_pull_port()
        push_port = remote_hook.get_push_port()

        worker_cls = multiprocessing.Process if process else threading.Thread
        hook_thread = worker_cls(
            target=_acq_hook_startup_fn,
            name="AcquisitionHook",
            args=(
                pull_port,
                push_port,
                hook_connected_evt,
                event_queue,
                remote_hook_fn,
                self._debug,
            ),
        )
        hook_thread.start()

        hook_connected_evt.wait()  # wait for push/pull sockets to connect

    def _start_processor(self, processor, process_fn, event_queue, process):
        """Start the worker that services the remote image processor.

        Parameters
        ----------
        processor :
            remote image processor object
        process_fn : Callable
            user image processing function handed to the worker
        event_queue :
            event queue handed to the worker
        process : bool
            use a ``multiprocessing.Process`` (and matching Event) instead of a thread
        """
        # this must start first
        processor.start_pull()

        sockets_connected_evt = multiprocessing.Event() if process else threading.Event()

        pull_port = processor.get_pull_port()
        push_port = processor.get_push_port()

        worker_cls = multiprocessing.Process if process else threading.Thread
        self.processor_thread = worker_cls(
            target=_processor_startup_fn,
            args=(
                pull_port,
                push_port,
                sockets_connected_evt,
                process_fn,
                event_queue,
                self._debug,
            ),
            name="ImageProcessor",
        )
        self.processor_thread.start()

        sockets_connected_evt.wait()  # wait for push/pull sockets to connect
        processor.start_push()
class Acquisition(object):
    """Python-side handle on an acquisition that runs in the Java layer.

    Constructing an instance creates the remote acquisition through a
    ``Bridge``, attaches any image processor / hooks that were supplied,
    starts it, and (for non-Magellan acquisitions) starts a process that
    forwards queued events to the remote side. Usable as a context manager.
    """

    def __init__(self, directory=None, name=None, image_process_fn=None,
                 pre_hardware_hook_fn=None, post_hardware_hook_fn=None,
                 tile_overlap=None, magellan_acq_index=None, process=True,
                 debug=False):
        """
        :param directory: saving directory for this acquisition. Required unless an image process function will be
            implemented that diverts images from saving
        :type directory: str
        :param name: Saving name for the acquisition. Required unless an image process function will be
            implemented that diverts images from saving
        :type name: str
        :param image_process_fn: image processing function that will be called on each image that gets acquired.
            Can either take two arguments (image, metadata) where image is a numpy array and metadata is a dict
            containing the corresponding image metadata. Or a 4 argument version is accepted, which accepts (image,
            metadata, bridge, queue), where bridge is an instance of the pycromanager.acquire.Bridge
            object for the purposes of interacting with arbitrary code on the Java side (such as the micro-manager
            core), and queue is a Queue object that holds upcoming acquisition events.
            NOTE(review): the original description of the required return value was truncated
            ("Both version must either return"); confirm against the processor implementation.
        :param pre_hardware_hook_fn: hook function that will be run just before the hardware is updated before acquiring
            a new image. Accepts either one argument (the current acquisition event) or three arguments (current event,
            bridge, event Queue)
        :param post_hardware_hook_fn: hook function that will be run just after the hardware is updated before acquiring
            a new image. Accepts either one argument (the current acquisition event) or three arguments (current event,
            bridge, event Queue)
        :param tile_overlap: If given, XY tiles will be laid out in a grid and multi-resolution saving will be
            activated. Argument can be a two element tuple describing the pixel overlaps between adjacent
            tiles. i.e. (pixel_overlap_x, pixel_overlap_y), or an integer to use the same overlap for both.
            For these features to work, the current hardware configuration must have a valid affine transform
            between camera coordinates and XY stage coordinates
        :type tile_overlap: tuple, int
        :param magellan_acq_index: run this acquisition using the settings specified at this position in the main
            GUI of micro-magellan (micro-manager plugin). This index starts at 0
        :type magellan_acq_index: int
        :param process: (Experimental) use multiprocessing instead of multithreading for acquisition hooks and image
            processors
        :type process: boolean
        :param debug: print debugging stuff
        :type debug: boolean
        """
        self.bridge = Bridge(debug=debug)
        self._debug = debug
        self._dataset = None

        if directory is not None:
            # Expand ~ in path
            directory = os.path.expanduser(directory)
            # If path is relative, retain knowledge of the current working directory
            directory = os.path.abspath(directory)

        if magellan_acq_index is not None:
            magellan_api = self.bridge.get_magellan()
            self._remote_acq = magellan_api.create_acquisition(magellan_acq_index)
            # Magellan acquisitions are not fed events from Python
            self._event_queue = None
        else:
            # Create thread safe queue for events so they can be passed from multiple processes
            self._event_queue = multiprocessing.Queue()
            core = self.bridge.get_core()
            acq_factory = self.bridge.construct_java_object(
                'org.micromanager.remote.RemoteAcquisitionFactory', args=[core])

            # TODO: could add hiding viewer as an option
            show_viewer = directory is not None and name is not None
            if tile_overlap is None:
                # argument placeholders, these wont actually be used
                x_overlap = 0
                y_overlap = 0
            elif isinstance(tile_overlap, tuple):
                x_overlap, y_overlap = tile_overlap
            else:
                x_overlap = tile_overlap
                y_overlap = tile_overlap

            self._remote_acq = acq_factory.create_acquisition(
                directory, name, show_viewer, tile_overlap is not None,
                x_overlap, y_overlap)

        if image_process_fn is not None:
            processor = self.bridge.construct_java_object(
                'org.micromanager.remote.RemoteImageProcessor')
            self._remote_acq.add_image_processor(processor)
            self._start_processor(processor, image_process_fn,
                                  self._event_queue, process=process)

        if pre_hardware_hook_fn is not None:
            # The acquisition is a constructor argument of the hook (same as the
            # post-hardware hook below), not an argument of add_hook
            hook = self.bridge.construct_java_object(
                'org.micromanager.remote.RemoteAcqHook',
                args=[self._remote_acq])
            self._start_hook(hook, pre_hardware_hook_fn, self._event_queue,
                             process=process)
            self._remote_acq.add_hook(hook,
                                      self._remote_acq.BEFORE_HARDWARE_HOOK)
        if post_hardware_hook_fn is not None:
            hook = self.bridge.construct_java_object(
                'org.micromanager.remote.RemoteAcqHook',
                args=[self._remote_acq])
            self._start_hook(hook, post_hardware_hook_fn, self._event_queue,
                             process=process)
            self._remote_acq.add_hook(hook,
                                      self._remote_acq.AFTER_HARDWARE_HOOK)

        self._remote_acq.start()

        if magellan_acq_index is None:
            self.event_port = self._remote_acq.get_event_port()

            self.event_process = multiprocessing.Process(
                target=_event_sending_fn,
                args=(self.event_port, self._event_queue, self._debug),
                name='Event sending')
            self.event_process.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._event_queue is not None:  # magellan acquisitions dont have this
            # this should shut down storage and viewer as appropriate
            self._event_queue.put(None)
        # now wait on it to finish
        self.await_completion()

    def get_dataset(self):
        """
        Return a :class:`~pycromanager.data.Dataset` object that has access to the underlying pixels

        :return: :class:`~pycromanager.data.Dataset` corresponding to this acquisition
        """
        if self._dataset is None:
            self._dataset = Dataset(
                remote_storage=self._remote_acq.get_storage())
        return self._dataset

    def await_completion(self):
        """
        Wait for acquisition to finish and resources to be cleaned up
        """
        while not self._remote_acq.is_finished():
            time.sleep(0.1)

    def acquire(self, events):
        """
        Submit an event or a list of events for acquisition. Optimizations (i.e. taking advantage of
        hardware synchronization, where available), will take place across this list of events, but not
        over multiple calls of this method. A single event is a python dictionary with a specific structure

        :param events: single event (i.e. a dictionary) or a list of events
        :raises RuntimeError: if this acquisition has no event queue (Magellan acquisitions)
        """
        if self._event_queue is None:
            raise RuntimeError(
                'This acquisition has no event queue (Magellan acquisitions do not accept events)')
        self._event_queue.put(events)

    def _start_hook(self, remote_hook, remote_hook_fn, event_queue, process):
        """
        Start the worker that services a remote hook and block until it has connected

        :param remote_hook: remote hook object providing ``get_pull_port`` and ``get_push_port``
        :param remote_hook_fn: user function handed to the worker
        :param event_queue: event queue handed to the worker
        :param process: run the worker in a separate process (True) or in a thread (False)
        """
        hook_connected_evt = multiprocessing.Event() if process else threading.Event()

        pull_port = remote_hook.get_pull_port()
        push_port = remote_hook.get_push_port()

        # The worker kind must match the Event kind: a threading.Event handed to a
        # separate process cannot signal back to this one
        worker_cls = multiprocessing.Process if process else threading.Thread
        hook_thread = worker_cls(
            target=_acq_hook_startup_fn,
            name='AcquisitionHook',
            args=(pull_port, push_port, hook_connected_evt, event_queue,
                  remote_hook_fn, self._debug))
        hook_thread.start()

        hook_connected_evt.wait()  # wait for push/pull sockets to connect

    def _start_processor(self, processor, process_fn, event_queue, process):
        """
        Start the worker that services the remote image processor

        :param processor: remote image processor object
        :param process_fn: user image processing function handed to the worker
        :param event_queue: event queue handed to the worker
        :param process: run the worker in a separate process (True) or in a thread (False)
        """
        # this must start first
        processor.start_pull()

        sockets_connected_evt = multiprocessing.Event() if process else threading.Event()

        pull_port = processor.get_pull_port()
        push_port = processor.get_push_port()

        # See _start_hook: worker kind must match the Event kind
        worker_cls = multiprocessing.Process if process else threading.Thread
        self.processor_thread = worker_cls(
            target=_processor_startup_fn,
            args=(pull_port, push_port, sockets_connected_evt, process_fn,
                  event_queue, self._debug),
            name='ImageProcessor')
        self.processor_thread.start()

        sockets_connected_evt.wait()  # wait for push/pull sockets to connect
        processor.start_push()
class Dataset:
    """Class that opens a single NDTiffStorage dataset"""

    # Names of the standard axes used as keys in axes dictionaries
    _POSITION_AXIS = "position"
    _ROW_AXIS = "row"
    _COLUMN_AXIS = "column"
    _Z_AXIS = "z"
    _TIME_AXIS = "time"
    _CHANNEL_AXIS = "channel"

    def __new__(cls, dataset_path=None, full_res_only=True, remote_storage=None):
        """Pick the reader class based on what is found on disk.

        With no ``dataset_path`` (e.g. a view of remote storage) a plain
        ``Dataset`` is created. Otherwise the "Full resolution" directory is
        checked for an "NDTiff.index" file: if present a ``Dataset`` is
        created, if absent a fully initialized ``Legacy_NDTiff_Dataset`` is
        returned instead.

        Raises
        ------
        Exception
            if ``dataset_path`` has no "Full resolution" subdirectory
        """
        if dataset_path is None:
            return super(Dataset, cls).__new__(Dataset)
        # Search for Full resolution dir, check for index
        res_dirs = [
            dI for dI in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, dI))
        ]
        if "Full resolution" not in res_dirs:
            raise Exception(
                "Couldn't find full resolution directory. Is this the correct path to a dataset?"
            )
        fullres_path = os.path.join(dataset_path, "Full resolution")
        if "NDTiff.index" in os.listdir(fullres_path):
            return super(Dataset, cls).__new__(Dataset)
        else:
            # Initialized by hand here: the object is created and handed back directly
            obj = Legacy_NDTiff_Dataset.__new__(Legacy_NDTiff_Dataset)
            obj.__init__(dataset_path, full_res_only, remote_storage)
            return obj

    def __init__(self, dataset_path=None, full_res_only=True, remote_storage=None):
        """
        Parameters
        ----------
        dataset_path : str
            path of the dataset directory on disk (it must contain a "Full resolution"
            subdirectory). Not used when ``remote_storage`` is given
        full_res_only : bool
            if True, only the "Full resolution" directory is opened and other
            resolution directories are skipped
        remote_storage :
            storage object of an active acquisition living on the Java side. If given, this
            dataset is only a view of it and nothing is read from disk
        """
        self._tile_width = None
        self._tile_height = None
        if remote_storage is not None:
            # this dataset is a view of an active acquisition. The storage exists on the java side
            self._remote_storage = remote_storage
            self._bridge = Bridge()
            smd = self._remote_storage.get_summary_metadata()
            if "GridPixelOverlapX" in smd.keys():
                self._tile_width = smd["Width"] - smd["GridPixelOverlapX"]
                self._tile_height = smd["Height"] - smd["GridPixelOverlapY"]
            return
        else:
            self._remote_storage = None

        self.path = dataset_path
        res_dirs = [
            dI for dI in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, dI))
        ]
        # map from resolution level index (0 = full resolution) to resolution level object
        self.res_levels = {}
        if "Full resolution" not in res_dirs:
            raise Exception(
                "Couldn't find full resolution directory. Is this the correct path to a dataset?"
            )
        num_tiffs = 0
        count = 0
        for res_dir in res_dirs:
            for file in os.listdir(os.path.join(dataset_path, res_dir)):
                if file.endswith(".tif"):
                    num_tiffs += 1
        for res_dir in res_dirs:
            if full_res_only and res_dir != "Full resolution":
                continue
            res_dir_path = os.path.join(dataset_path, res_dir)
            res_level = _ResolutionLevel(res_dir_path, count, num_tiffs)
            if res_dir == "Full resolution":
                self.res_levels[0] = res_level
                # get summary metadata and index tree from full resolution image
                self.summary_metadata = res_level.summary_metadata

                self.overlap = (
                    np.array(
                        [
                            self.summary_metadata["GridPixelOverlapY"],
                            self.summary_metadata["GridPixelOverlapX"],
                        ]
                    )
                    if "GridPixelOverlapY" in self.summary_metadata
                    else None
                )

                # map of axis name to the set of positions along it that occur in the index
                self.axes = {}
                for axes_combo in res_level.index.keys():
                    for axis, position in axes_combo:
                        if axis not in self.axes.keys():
                            self.axes[axis] = set()
                        self.axes[axis].add(position)

                # figure out the mapping of channel name to position by reading image metadata
                print("\rReading channel names...", end="")
                if self._CHANNEL_AXIS in self.axes.keys():
                    self._channel_names = {}
                    for key in res_level.index.keys():
                        axes = {axis: position for axis, position in key}
                        if (
                            self._CHANNEL_AXIS in axes.keys()
                            and axes[self._CHANNEL_AXIS] not in self._channel_names.values()
                        ):
                            channel_name = res_level.read_metadata(axes)["Channel"]
                            self._channel_names[channel_name] = axes[self._CHANNEL_AXIS]
                        # stop as soon as every channel index has a name
                        if len(self._channel_names.values()) == len(self.axes[self._CHANNEL_AXIS]):
                            break
                print("\rFinished reading channel names", end="")

                # remove axes with no variation
                single_axes = [axis for axis in self.axes if len(self.axes[axis]) == 1]
                for axis in single_axes:
                    del self.axes[axis]

                # If the dataset uses XY stitching, map out the row and col indices
                if (
                    "TiledImageStorage" in self.summary_metadata
                    and self.summary_metadata["TiledImageStorage"]
                ):
                    # Make an n x 2 array with nan's where no positions actually exist
                    pass

            else:
                # directory name is split on "x"; the number after it is the downsample factor
                self.res_levels[int(np.log2(int(res_dir.split("x")[1])))] = res_level

        # get information about image width and height, assuming that they are consistent for whole dataset
        # (which isn't strictly necessary)
        first_index = list(self.res_levels[0].index.values())[0]
        if first_index["pixel_type"] == _MultipageTiffReader.EIGHT_BIT_RGB:
            self.bytes_per_pixel = 3
            self.dtype = np.uint8
        elif first_index["pixel_type"] == _MultipageTiffReader.EIGHT_BIT:
            self.bytes_per_pixel = 1
            self.dtype = np.uint8
        elif first_index["pixel_type"] == _MultipageTiffReader.SIXTEEN_BIT:
            self.bytes_per_pixel = 2
            self.dtype = np.uint16

        self.image_width = first_index["image_width"]
        self.image_height = first_index["image_height"]
        if "GridPixelOverlapX" in self.summary_metadata:
            self._tile_width = self.image_width - self.summary_metadata["GridPixelOverlapX"]
            self._tile_height = self.image_height - self.summary_metadata["GridPixelOverlapY"]

        print("\rDataset opened ")

    def as_array(self, stitched=False, verbose=True):
        """
        Read all data image data as one big Dask array with last two axes as y, x and preceeding axes depending on data.
        The dask array is made up of memory-mapped numpy arrays, so the dataset does not need to be able to fit into RAM.
        If the data doesn't fully fill out the array (e.g. not every z-slice collected at every time point), zeros will
        be added automatically.

        To convert data into a numpy array, call np.asarray() on the returned result. However, doing so will bring the
        data into RAM, so it may be better to do this on only a slice of the array at a time.

        Parameters
        ----------
        stitched : bool
            If true and tiles were acquired in a grid, lay out adjacent tiles next to one another (Default value = False)
        verbose : bool
            If True print updates on progress loading the image

        Returns
        -------
        dataset : dask array

        Raises
        ------
        Exception
            if this dataset is a view of an in progress acquisition
        """
        if self._remote_storage is not None:
            raise Exception("Method not yet implemented for in progress acquisitions")

        # placeholder tile used wherever an axes combination has no image
        w = self.image_width if not stitched else self._tile_width
        h = self.image_height if not stitched else self._tile_height
        is_rgb = self.bytes_per_pixel == 3
        self._empty_tile = (
            np.zeros((h, w, 3), self.dtype) if is_rgb else np.zeros((h, w), self.dtype)
        )
        self._count = 1
        total = np.prod([len(v) for v in self.axes.values()])

        def recurse_axes(loop_axes, point_axes):
            """Build nested lists of tiles: loop_axes still need iterating, point_axes are fixed."""
            if len(loop_axes.values()) == 0:
                if verbose:
                    print("\rAdding data chunk {} of {}".format(self._count, total), end="")
                self._count += 1
                if None not in point_axes.values() and self.has_image(**point_axes):
                    if stitched:
                        img = self.read_image(**point_axes, memmapped=True)
                        if self.half_overlap[0] != 0:
                            img = img[
                                self.half_overlap[0] : -self.half_overlap[0],
                                self.half_overlap[1] : -self.half_overlap[1],
                            ]
                        return img
                    else:
                        return self.read_image(**point_axes, memmapped=True)
                else:
                    return self._empty_tile
            else:
                # do position first because it makes stitching faster
                axis = (
                    "position"
                    if "position" in loop_axes.keys() and stitched
                    else list(loop_axes.keys())[0]
                )
                remaining_axes = loop_axes.copy()
                del remaining_axes[axis]
                if axis == "position" and stitched:
                    # Stitch tiles acquired in a grid
                    self.half_overlap = (self.overlap[0] // 2, self.overlap[1] // 2)

                    # get spatial layout of position indices
                    # NOTE(review): row_col_array is not assigned anywhere in this class as shown;
                    # confirm where it is expected to be populated
                    zero_min_row_col = self.row_col_array - np.nanmin(self.row_col_array, axis=0)
                    row_col_mat = np.nan * np.ones(
                        [
                            int(np.nanmax(zero_min_row_col[:, 0])) + 1,
                            int(np.nanmax(zero_min_row_col[:, 1])) + 1,
                        ]
                    )
                    positions_indices = np.array(list(loop_axes["position"]))
                    rows = zero_min_row_col[positions_indices][:, 0]
                    cols = zero_min_row_col[positions_indices][:, 1]
                    # mask in case some positions were corrupted
                    mask = np.logical_not(np.isnan(rows))
                    # builtin int: the np.int alias no longer exists in current NumPy
                    row_col_mat[
                        rows[mask].astype(int), cols[mask].astype(int)
                    ] = positions_indices[mask]

                    blocks = []
                    for row in row_col_mat:
                        blocks.append([])
                        for p_index in row:
                            if verbose:
                                print(
                                    "\rAdding data chunk {} of {}".format(self._count, total),
                                    end="",
                                )
                            valed_axes = point_axes.copy()
                            # nan marks a grid cell with no position; None yields the empty tile
                            valed_axes[axis] = int(p_index) if not np.isnan(p_index) else None
                            blocks[-1].append(da.stack(recurse_axes(remaining_axes, valed_axes)))

                    if is_rgb:
                        stitched_array = np.concatenate(
                            [
                                np.concatenate(row, axis=len(blocks[0][0].shape) - 2)
                                for row in blocks
                            ],
                            axis=len(blocks[0][0].shape) - 3,
                        )
                    else:
                        stitched_array = da.block(blocks)
                    return stitched_array
                else:
                    blocks = []
                    for val in loop_axes[axis]:
                        valed_axes = point_axes.copy()
                        valed_axes[axis] = val
                        blocks.append(recurse_axes(remaining_axes, valed_axes))
                    return blocks

        blocks = recurse_axes(self.axes, {})

        if verbose:
            print(
                "\rStacking tiles... "
            )  # extra space otherwise there is no space after the "Adding data chunk {} {}"
        array = da.stack(blocks, allow_unknown_chunksizes=False)
        if verbose:
            print("\rDask array opened")
        return array

    def has_image(
        self,
        channel=0,
        z=None,
        time=None,
        position=None,
        channel_name=None,
        resolution_level=0,
        row=None,
        col=None,
        **kwargs
    ):
        """Check if this image is present in the dataset

        Parameters
        ----------
        channel : int
            index of the channel, if applicable (Default value = 0)
        z : int
            index of z slice, if applicable (Default value = None)
        time : int
            index of the time point, if applicable (Default value = None)
        position : int
            index of the XY position, if applicable (Default value = None)
        channel_name : str
            Name of the channel. Overrides channel index if supplied (Default value = None)
        row : int
            index of tile row for XY tiled datasets (Default value = None)
        col : int
            index of tile col for XY tiled datasets (Default value = None)
        resolution_level :
            0 is full resolution, otherwise represents downampling of pixels
            at 2 ** (resolution_level) (Default value = 0)
        **kwargs
            names and integer positions of any other axes

        Returns
        -------
        bool :
            indicating whether the dataset has an image matching the specifications
        """
        if self._remote_storage is not None:
            if row is not None and col is not None:
                # row/col are passed to the remote call separately, so keep them out of the map
                java_axes = self._to_java_axes(
                    self._consolidate_axes(
                        channel, channel_name, z, position, time, None, None, kwargs
                    )
                )
                return self._remote_storage.has_tile_by_row_col(
                    java_axes, resolution_level, row, col
                )
            java_axes = self._to_java_axes(
                self._consolidate_axes(channel, channel_name, z, position, time, row, col, kwargs)
            )
            return self._remote_storage.has_image(java_axes, resolution_level)

        return self.res_levels[0].has_image(
            self._consolidate_axes(channel, channel_name, z, position, time, row, col, kwargs)
        )

    def read_image(
        self,
        channel=0,
        z=None,
        time=None,
        position=None,
        row=None,
        col=None,
        channel_name=None,
        resolution_level=0,
        memmapped=False,
        **kwargs
    ):
        """
        Read image data as numpy array

        Parameters
        ----------
        channel : int
            index of the channel, if applicable (Default value = 0)
        z : int
            index of z slice, if applicable (Default value = None)
        time : int
            index of the time point, if applicable (Default value = None)
        position : int
            index of the XY position, if applicable (Default value = None)
        channel_name :
            Name of the channel. Overrides channel index if supplied (Default value = None)
        row : int
            index of tile row for XY tiled datasets (Default value = None)
        col : int
            index of tile col for XY tiled datasets (Default value = None)
        resolution_level :
            0 is full resolution, otherwise represents downampling of pixels
            at 2 ** (resolution_level) (Default value = 0)
        memmapped : bool
            passed through to the resolution level reader; not available for
            in progress acquisitions (Default value = False)
        **kwargs :
            names and integer positions of any other axes

        Returns
        -------
        image : numpy array
            image as a 2D numpy array, or None if this is an in progress acquisition
            and the remote storage does not have the requested image

        Raises
        ------
        Exception
            if ``memmapped`` is requested for an in progress acquisition
        """
        axes = self._consolidate_axes(channel, channel_name, z, position, time, row, col, kwargs)

        if self._remote_storage is not None:
            if memmapped:
                raise Exception("Memory mapping not available for in progress acquisitions")
            java_axes = self._to_java_axes(axes)
            if not self._remote_storage.has_image(java_axes, resolution_level):
                return None
            tagged_image = self._remote_storage.get_image(java_axes, resolution_level)
            if resolution_level == 0:
                image = np.reshape(
                    tagged_image.pix,
                    [tagged_image.tags["Height"], tagged_image.tags["Width"]],
                )
                if (self._tile_height is not None) and (self._tile_width is not None):
                    # crop down to just the part that shows (i.e. no overlap). Slicing by
                    # start + size (rather than a negative end) stays correct at zero overlap
                    top = (image.shape[0] - self._tile_height) // 2
                    left = (image.shape[1] - self._tile_width) // 2
                    image = image[
                        top : top + self._tile_height,
                        left : left + self._tile_width,
                    ]
            else:
                image = np.reshape(tagged_image.pix, [self._tile_height, self._tile_width])
            return image
        else:
            res_level = self.res_levels[resolution_level]
            return res_level.read_image(axes, memmapped)

    def read_metadata(
        self,
        channel=0,
        z=None,
        time=None,
        position=None,
        channel_name=None,
        row=None,
        col=None,
        resolution_level=0,
        **kwargs
    ):
        """
        Read metadata only. Faster than using read_image to retrieve metadata

        Parameters
        ----------
        channel : int
            index of the channel, if applicable (Default value = 0)
        z : int
            index of z slice, if applicable (Default value = None)
        time : int
            index of the time point, if applicable (Default value = None)
        position : int
            index of the XY position, if applicable (Default value = None)
        channel_name :
            Name of the channel. Overrides channel index if supplied (Default value = None)
        row : int
            index of tile row for XY tiled datasets (Default value = None)
        col : int
            index of tile col for XY tiled datasets (Default value = None)
        resolution_level :
            0 is full resolution, otherwise represents downampling of pixels
            at 2 ** (resolution_level) (Default value = 0)
        **kwargs :
            names and integer positions of any other axes

        Returns
        -------
        metadata : dict
            or None if this is an in progress acquisition and the remote storage
            does not have the requested image
        """
        axes = self._consolidate_axes(channel, channel_name, z, position, time, row, col, kwargs)

        if self._remote_storage is not None:
            java_axes = self._to_java_axes(axes)
            if not self._remote_storage.has_image(java_axes, resolution_level):
                return None
            # TODO: could speed this up a lot on the Java side by only reading metadata instead of pixels too
            return self._remote_storage.get_image(java_axes, resolution_level).tags
        else:
            res_level = self.res_levels[resolution_level]
            return res_level.read_metadata(axes)

    def close(self):
        """Close every opened resolution level (no-op for in progress acquisitions)."""
        if self._remote_storage is not None:
            # nothing to do, this is handled on the java side
            return
        # res_levels maps level index -> resolution level object; close the objects
        for res_level in self.res_levels.values():
            res_level.close()

    def get_channel_names(self):
        """Return the channel names read from image metadata when the dataset was opened.

        Raises
        ------
        Exception
            if this dataset is a view of an in progress acquisition
        """
        if self._remote_storage is not None:
            raise Exception("Not implemented for in progress datasets")
        return self._channel_names.keys()

    def _to_java_axes(self, axes):
        """Copy a Python axes dict into a new java.util.HashMap created through the bridge."""
        java_axes = self._bridge.construct_java_object("java.util.HashMap")
        for key, value in axes.items():
            java_axes.put(key, value)
        return java_axes

    def _consolidate_axes(self, channel, channel_name, z, position, time, row, col, kwargs):
        """Merge the named axis arguments and any extra axes into a single dict.

        Arguments that are None are left out. ``channel_name``, if given, is looked up
        in the channel name mapping and overrides ``channel``. Entries of ``kwargs``
        are copied last and therefore override same-named standard axes.
        """
        axes = {}
        if channel is not None:
            axes[self._CHANNEL_AXIS] = channel
        if channel_name is not None:
            axes[self._CHANNEL_AXIS] = self._channel_names[channel_name]
        if z is not None:
            axes[self._Z_AXIS] = z
        if position is not None:
            axes[self._POSITION_AXIS] = position
        if time is not None:
            axes[self._TIME_AXIS] = time
        if row is not None:
            axes[self._ROW_AXIS] = row
        if col is not None:
            axes[self._COLUMN_AXIS] = col
        axes.update(kwargs)
        return axes