def test_post_mappings(labelmap_setup):
    """
    Test the wrapper function for the /mappings DVID API.
    """
    dvid_server, dvid_repo, _merge_table_path, _mapping_path, _supervoxel_vol = labelmap_setup
    instance_info = DvidInstanceInfo(dvid_server, dvid_repo, 'segmentation')

    # Fetch the original mapping
    orig_mapping = fetch_mappings(*instance_info).sort_index()
    assert (orig_mapping.index == [2, 3, 4, 5]).all()
    assert (orig_mapping == 1).all()  # see initialization in conftest.py

    # Make sure post_mappings() doesn't require the Series (or its index)
    # to have any particular name.
    orig_mapping.index.name = 'barfoo'
    orig_mapping.name = 'foobar'

    # Now post a new mapping and read it back.
    new_mapping = orig_mapping.copy()
    new_mapping[:] = 2
    new_mapping.sort_index(inplace=True)

    # Post all but sv 5
    post_mappings(*instance_info, new_mapping.iloc[:-1], mutid=1)
    fetched_mapping = fetch_mappings(*instance_info).sort_index()
    assert (fetched_mapping.index == [3, 4, 5]).all()
    assert (fetched_mapping.iloc[:-1] == 2).all()
    assert (fetched_mapping.iloc[-1:] == 1).all()

    # Now post sv 5, too
    post_mappings(*instance_info, new_mapping.iloc[-1:], mutid=1)
    fetched_mapping = fetch_mappings(*instance_info).sort_index()
    assert (fetched_mapping.index == [3, 4, 5]).all()
    assert (fetched_mapping == 2).all()

    # Try batched
    new_mapping = pd.Series(index=[1, 2, 3, 4, 5], data=[1, 1, 1, 2, 2])
    post_mappings(*instance_info, new_mapping, mutid=1, batch_size=4)
    fetched_mapping = fetch_mappings(*instance_info).sort_index()
    assert (fetched_mapping.index == [2, 3, 4, 5]).all()
    assert (fetched_mapping == [1, 1, 2, 2]).all(), f"{fetched_mapping} != {[1,1,2,2]}"

    # Restore the original mapping
    post_mappings(*instance_info, orig_mapping, 1)
    fetched_mapping = fetch_mappings(*instance_info).sort_index()
    assert (fetched_mapping.index == [2, 3, 4, 5]).all()
    assert (fetched_mapping == 1).all()
def test_fetch_mappings(labelmap_setup):
    """
    Test the wrapper function for the /mappings DVID API.
    """
    dvid_server, dvid_repo, _merge_table_path, _mapping_path, _supervoxel_vol = labelmap_setup
    instance_info = DvidInstanceInfo(dvid_server, dvid_repo, 'segmentation')

    mapping = fetch_mappings(*instance_info)
    assert isinstance(mapping, pd.Series)
    assert mapping.index.name == 'sv'
    assert mapping.name == 'body'

    # Does not include 'identity' row for SV 1. See docstring.
    assert (sorted(mapping.index) == [2, 3, 4, 5])
    assert (mapping == 1).all()  # see initialization in conftest.py
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--split-into-batches', type=int,
                        help='If given, also split the body stats into this many batches of roughly equal size')
    parser.add_argument('server')
    parser.add_argument('src_uuid')
    parser.add_argument('labelmap_instance')
    parser.add_argument('supervoxel_block_stats_h5',
                        help=f'An HDF5 file with a single dataset "stats", with dtype: {STATS_DTYPE[1:]} (Note: No column for body_id)')
    args = parser.parse_args()

    configure_default_logging()
    initialize_excepthook()

    (block_sv_stats, _presorted_by, _agglo_path) = load_stats_h5_to_records(args.supervoxel_block_stats_h5)

    src_info = (args.server, args.src_uuid, args.labelmap_instance)
    mapping = fetch_mappings(*src_info)

    assert isinstance(mapping, pd.Series)
    mapping_df = mapping.reset_index().rename(columns={'sv': 'segment_id', 'body': 'body_id'})

    # sorts in-place, and saves a copy to hdf5
    sort_block_stats(block_sv_stats,
                     mapping_df,
                     args.supervoxel_block_stats_h5[:-3] + '-sorted-by-body.h5',
                     '<fetched-from-dvid>')

    if args.split_into_batches:
        num_batches = args.split_into_batches
        batch_size = int(np.ceil(len(block_sv_stats) / args.split_into_batches))
        logger.info(f"Splitting into {args.split_into_batches} batches of size ~{batch_size}")
        os.makedirs('stats-batches', exist_ok=True)

        body_spans = groupby_spans_presorted(block_sv_stats['body_id'][:, None])
        for batch_index, batch_spans in enumerate(tqdm_proxy(iter_batches(body_spans, batch_size))):
            span_start, span_stop = batch_spans[0][0], batch_spans[-1][1]
            batch_stats = block_sv_stats[span_start:span_stop]
            digits = int(np.ceil(np.log10(num_batches)))
            batch_path = ('stats-batches/stats-batch-{:0' + str(digits) + 'd}.h5').format(batch_index)
            save_stats(batch_stats, batch_path)

    logger.info("DONE sorting stats by body")
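
# Illustrative sketch (not part of the original script): the batching above builds
# batches from whole body "spans" of the presorted stats, so a body's rows are not
# split across batch files.  The real helpers (groupby_spans_presorted, iter_batches)
# come from neuclease.util and may behave differently; the toy function below is
# hypothetical and only assumes a presorted 1-D array of body IDs.
def _toy_body_spans(body_ids):
    """
    Yield (start, stop) row indices for each run of identical, presorted body IDs.

    >>> list(_toy_body_spans(np.array([10, 10, 10, 20, 20, 30])))
    [(0, 3), (3, 5), (5, 6)]
    """
    boundaries = np.flatnonzero(np.diff(body_ids)) + 1
    edges = [0, *boundaries.tolist(), len(body_ids)]
    for start, stop in zip(edges[:-1], edges[1:]):
        yield (start, stop)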
def sparse_brick_coords_for_labels(self, labels, clip=True):
    """
    Return a DataFrame indicating the brick coordinates (starting corner)
    that encompass the given labels.

    Args:
        labels:
            A list of body IDs (if ``self.supervoxels`` is False),
            or supervoxel IDs (if ``self.supervoxels`` is True).

        clip:
            If True, filter the results to exclude any coordinates
            that fall outside this service's bounding-box.
            Otherwise, all brick coordinates that encompass the given labels
            will be returned, whether or not they fall within the bounding box.

    Returns:
        DataFrame with columns [z,y,x,label],
        where z,y,x represents the starting corner (in full-res coordinates)
        of a brick that contains the label.
    """
    assert not isinstance(labels, set), "Pass labels as a list or array, not a set"
    labels = pd.unique(labels)
    is_supervoxels = self.supervoxels
    brick_shape = self.preferred_message_shape
    assert (brick_shape % self.block_width == 0).all(), \
        ("Brick shape ('preferred-message-shape') must be a multiple of the "
         f"block width ({self.block_width}) in all dimensions, not {brick_shape}")

    bad_labels = []

    if not is_supervoxels:
        # No supervoxel filtering.
        # Sort by body, since that should be slightly nicer for dvid performance.
        bodies_and_svs = {label: None for label in sorted(labels)}
    else:
        # Arbitrary heuristic for whether to do the body-lookups on DVID or on the client.
        if len(labels) < 100_000:
            # If we're only dealing with a few supervoxels,
            # ask dvid to map them to bodies for us.
            mapping = fetch_mapping(*self.instance_triple, labels, as_series=True)
        else:
            # If we're dealing with a lot of supervoxels, ask for
            # the entire mapping, and look up the bodies ourselves.
            complete_mapping = fetch_mappings(*self.instance_triple)
            mapper = LabelMapper(complete_mapping.index.values, complete_mapping.values)

            labels = np.asarray(labels, np.uint64)
            bodies = mapper.apply(labels, True)
            mapping = pd.Series(index=labels, data=bodies, name='body')
            mapping.index.rename('sv', inplace=True)

        bad_svs = mapping[mapping == 0]
        bad_labels.extend(bad_svs.index.tolist())

        # Group by body
        mapping = mapping[mapping != 0]
        grouped_svs = mapping.reset_index().groupby('body').agg({'sv': list})['sv']

        # Sort by body, since that should be slightly nicer for dvid performance.
        bodies_and_svs = grouped_svs.sort_index().to_dict()

    # Extract these to avoid pickling 'self' (just for speed)
    server, uuid, instance = self.instance_triple
    if self._use_resource_manager_for_sparse_coords:
        mgr = self.resource_manager_client
    else:
        mgr = ResourceManagerClient("", 0)

    def fetch_brick_coords(body, supervoxel_subset):
        """
        Fetch the block coordinates for the given body,
        filter them for the given supervoxels (if any),
        and convert the block coordinates to brick coordinates.
        """
        assert is_supervoxels or supervoxel_subset is None
        try:
            with mgr.access_context(server, True, 1, 1):
                labelindex = fetch_labelindex(server, uuid, instance, body, 'protobuf')
            coords_df = convert_labelindex_to_pandas(labelindex).blocks
        except HTTPError as ex:
            if (ex.response is not None and ex.response.status_code == 404):
                return (body, None)
            raise
        except RuntimeError as ex:
            if 'does not map to any body' in str(ex):
                return (body, None)
            raise

        if len(coords_df) == 0:
            return (body, None)

        if is_supervoxels:
            supervoxel_subset = set(supervoxel_subset)
            coords_df = coords_df.query('sv in @supervoxel_subset').copy()

        coords_df[['z', 'y', 'x']] //= brick_shape
        coords_df['body'] = np.uint64(body)
        coords_df.drop_duplicates(inplace=True)
        return (body, coords_df)

    def fetch_and_concatenate_brick_coords(bodies_and_supervoxels):
        """
        To reduce the number of tiny DataFrames collected to the driver,
        it's best to concatenate the partitions first, on the workers,
        rather than a straightforward call to starmap(fetch_brick_coords).
        Hence, this function that consolidates each partition.
        """
        bad_bodies = []
        coord_dfs = []
        for (body, supervoxel_subset) in bodies_and_supervoxels:
            _, coords_df = fetch_brick_coords(body, supervoxel_subset)
            if coords_df is None:
                bad_bodies.append(body)
            else:
                coord_dfs.append(coords_df)
                del coords_df

        if coord_dfs:
            return [(pd.concat(coord_dfs, ignore_index=True), bad_bodies)]
        else:
            return [(None, bad_bodies)]

    with Timer(f"Fetching coarse sparsevols for {len(labels)} labels ({len(bodies_and_svs)} bodies)", logger=logger):
        import dask.bag as db
        coords_and_bad_bodies = (
            db.from_sequence(bodies_and_svs.items(), npartitions=4096)  # Instead of fancy heuristics, just pick 4096
              .map_partitions(fetch_and_concatenate_brick_coords)
              .compute())

    coords_df_partitions, bad_body_partitions = zip(*coords_and_bad_bodies)

    for body in chain(*bad_body_partitions):
        if is_supervoxels:
            bad_labels.extend(bodies_and_svs[body])
        else:
            bad_labels.append(body)

    if bad_labels:
        name = 'sv' if is_supervoxels else 'body'
        pd.Series(bad_labels, name=name).to_csv('labels-without-sparsevols.csv', index=False, header=True)
        if len(bad_labels) < 100:
            msg = f"Could not obtain coarse sparsevol for {len(bad_labels)} labels: {bad_labels}"
        else:
            msg = f"Could not obtain coarse sparsevol for {len(bad_labels)} labels. See labels-without-sparsevols.csv"
        logger.warning(msg)

    coords_df_partitions = list(filter(lambda df: df is not None, coords_df_partitions))
    if len(coords_df_partitions) == 0:
        raise RuntimeError("Could not find bricks for any of the given labels")

    coords_df = pd.concat(coords_df_partitions, ignore_index=True)

    if self.supervoxels:
        coords_df['label'] = coords_df['sv']
    else:
        coords_df['label'] = coords_df['body']

    coords_df.drop_duplicates(['z', 'y', 'x', 'label'], inplace=True)
    coords_df[['z', 'y', 'x']] *= brick_shape

    if clip:
        # Keep if the last pixel in the brick is to the right of the bounding-box start
        # and the first pixel in the brick is to the left of the bounding-box stop
        keep =  (coords_df[['z', 'y', 'x']] + brick_shape > self.bounding_box_zyx[0]).all(axis=1)
        keep &= (coords_df[['z', 'y', 'x']] < self.bounding_box_zyx[1]).all(axis=1)
        coords_df = coords_df.loc[keep]

    return coords_df[['z', 'y', 'x', 'label']]
                     *, threads=0, processes=0, last_mutid=None, mapping=None):
    assert not (threads and processes), \
        "Use threads or processes (or neither), but not both."

    if last_mutid is None:
        last_mutid = fetch_repo_info(*src_info[:2])["MutationID"]

    (block_sv_stats, presorted_by, _agglo_path) = load_stats_h5_to_records(erased_block_stats_h5)

    if presorted_by != 'body_id':
        if mapping is None:
            mapping = fetch_mappings(*src_info)

        assert isinstance(mapping, pd.Series)
        mapping_df = mapping.reset_index().rename(columns={'sv': 'segment_id', 'body': 'body_id'})

        # sorts in-place, and saves a copy to hdf5
        sort_block_stats(block_sv_stats,
                         mapping_df,
                         erased_block_stats_h5[:-3] + '-sorted-by-body.h5',
                         '<fetched-from-dvid>')

    if threads > 0:
        pool = ThreadPool(threads)
    elif processes > 0:
def fetch_roi_synapses(server, uuid, synapses_instance, rois, fetch_labels=False, return_partners=False, processes=16):
    """
    Fetch the coordinates and (optionally) body labels for
    all synapses that fall within the given ROIs.

    Args:
        server:
            DVID server, e.g. 'emdata4:8900'

        uuid:
            DVID uuid, e.g. 'abc9'

        synapses_instance:
            DVID synapses instance name, e.g. 'synapses'

        rois:
            A single DVID ROI instance name or a list of them,
            e.g. 'EB' or ['EB', 'FB']

        fetch_labels:
            If True, also fetch the supervoxel and body label underneath each synapse,
            returned in columns 'sv' and 'body'.

        return_partners:
            If True, also return the partners table.

        processes:
            How many parallel processes to use when fetching synapses and supervoxel labels.

    Returns:
        pandas DataFrame with columns:
        ``['z', 'y', 'x', 'kind', 'conf']`` and ``['sv', 'body']`` (if ``fetch_labels=True``)

        If return_partners is True, also return the partners table.

    Example:
        df = fetch_roi_synapses('emdata4:8900', '3c281', 'synapses',
                                ['PB(L5)', 'PB(L7)'], fetch_labels=True, processes=8)
    """
    # Late imports to avoid circular imports in dvid/__init__
    from neuclease.dvid import fetch_combined_roi_volume, determine_point_rois, fetch_labels_batched, fetch_mapping, fetch_mappings

    assert rois, "No rois provided, result would be empty. Is that what you meant?"

    if isinstance(rois, str):
        rois = [rois]

    # Determine name of the segmentation instance that's
    # associated with the given synapses instance.
    syn_info = fetch_instance_info(server, uuid, synapses_instance)
    seg_instance = syn_info["Base"]["Syncs"][0]

    logger.info(f"Fetching mask for ROIs: {rois}")
    # Fetch the ROI as a low-res array (scale 5, i.e. 32-px resolution)
    roi_vol_s5, roi_box_s5, overlapping_pairs = fetch_combined_roi_volume(server, uuid, rois)

    if len(overlapping_pairs) > 0:
        logger.warning("Some ROIs overlapped and are thus not completely represented in the output:\n"
                       f"{overlapping_pairs}")

    # Convert to full-res box
    roi_box = (2**5) * roi_box_s5

    # fetch_synapses_in_batches() requires a box that is 64-px-aligned
    roi_box = round_box(roi_box, 64, 'out')

    logger.info("Fetching synapse points")
    # points_df is a DataFrame with columns for [z,y,x]
    points_df, partners_df = fetch_synapses_in_batches(server, uuid, synapses_instance, roi_box, processes=processes)

    # Append 'roi' and 'roi_label' columns to points_df
    logger.info("Labeling ROI for each point")
    determine_point_rois(server, uuid, rois, points_df, roi_vol_s5, roi_box_s5)

    logger.info("Discarding points that don't overlap with the roi")
    rois = {*rois}
    points_df = points_df.query('roi in @rois').copy()

    columns = ['z', 'y', 'x', 'kind', 'conf', 'roi_label', 'roi']

    if fetch_labels:
        logger.info("Fetching supervoxel under each point")
        svs = fetch_labels_batched(server, uuid, seg_instance,
                                   points_df[['z', 'y', 'x']].values,
                                   supervoxels=True,
                                   processes=processes)

        with Timer("Mapping supervoxels to bodies", logger):
            # Arbitrary heuristic for whether to do the
            # body-lookups on DVID or on the client.
            if len(svs) < 100_000:
                bodies = fetch_mapping(server, uuid, seg_instance, svs)
            else:
                mapping = fetch_mappings(server, uuid, seg_instance)
                mapper = LabelMapper(mapping.index.values, mapping.values)
                bodies = mapper.apply(svs, True)

        points_df['sv'] = svs
        points_df['body'] = bodies
        columns += ['body', 'sv']

    if return_partners:
        # Filter
        #partners_df = partners_df.query('post_id in @points_df.index and pre_id in @points_df.index').copy()

        # Faster filter (via merge)
        partners_df = partners_df.merge(points_df[[]], 'inner', left_on='pre_id', right_index=True)
        partners_df = partners_df.merge(points_df[[]], 'inner', left_on='post_id', right_index=True)
        return points_df[columns], partners_df
    else:
        return points_df[columns]
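
# Illustrative sketch (not part of the original module): the "faster filter (via merge)"
# trick used above.  Merging against an empty-column slice (points_df[[]]) performs an
# inner join on the index alone, keeping only partner rows whose pre_id/post_id appear
# in points_df.index without materializing any extra columns.  The helper name and toy
# data below are hypothetical.
def _demo_merge_index_filter():
    import pandas as pd
    points = pd.DataFrame({'z': [0, 1, 2]}, index=[10, 11, 12])
    partners = pd.DataFrame({'pre_id': [10, 10, 99], 'post_id': [11, 99, 12]})

    # Keep rows whose pre_id AND post_id are both present in points.index
    filtered = partners.merge(points[[]], 'inner', left_on='pre_id', right_index=True)
    filtered = filtered.merge(points[[]], 'inner', left_on='post_id', right_index=True)
    return filtered  # -> only the row with pre_id=10, post_id=11 survives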