def mitos_in_neighborhood(mito_roi_source, neighborhood_origin_xyz, neighborhood_id, mito_res_scale_diff):
    """
    Tabulate the non-trivial mito objects that overlap a "neighborhood object".

    Procedure:

    1. Fetch the neighborhood segmentation around the given origin and
       select the voxels belonging to ``neighborhood_id``.
    2. Remove the mask's inner edge (a 1-px erosion).
    3. Fetch the mito supervoxel segmentation for the same box and
       discard everything outside the eroded mask.
    4. Look up the size of each remaining mito ID.
    5. Drop mitos smaller than ``MIN_MITO_SIZE``.
    6. For extra information, group the surviving mito voxels into
       connected components.
    7. Return the mito IDs, sizes, and connected-component info.

    Args:
        mito_roi_source:
            String of the form ``protocol://server/uuid/instance``
            naming the neighborhood segmentation.
        neighborhood_origin_xyz:
            Center of the neighborhood, in X,Y,Z order.
        neighborhood_id:
            Label of the neighborhood object in the neighborhood segmentation.
        mito_res_scale_diff:
            Number of scales separating the mito segmentation from the
            neighborhood segmentation.  It is subtracted from ANALYSIS_SCALE
            when fetching mito voxels, and the fetched sizes are multiplied
            by ``(2**mito_res_scale_diff)**3``.

    Returns:
        DataFrame indexed by 'mito', holding the mito sizes left-merged
        with the 'cc' and 'cc_size' columns.
    """
    # Break the neighborhood segmentation source into (server, uuid, instance).
    protocol, url = mito_roi_source.split('://')[-2:]
    server, uuid, instance = url.split('/')
    server = f'{protocol}://{server}'

    center_zyx = np.array(neighborhood_origin_xyz[::-1])

    # Build the box around the center, snap it to the analysis-scale grid,
    # and only then convert it to analysis-scale coordinates.
    scale_factor = 2**ANALYSIS_SCALE
    fetch_box = round_box([center_zyx - RADIUS, 1 + center_zyx + RADIUS], scale_factor)
    fetch_box //= scale_factor

    neighborhood_seg = fetch_labelmap_voxels(server, uuid, instance, fetch_box, scale=ANALYSIS_SCALE)
    neighborhood_mask = (neighborhood_seg == neighborhood_id)

    # XOR-ing away the inner edge is equivalent to a 1-px erosion.
    inner_edge = binary_edge_mask(neighborhood_mask, 'inner')
    neighborhood_mask ^= inner_edge

    mito_scale = ANALYSIS_SCALE - mito_res_scale_diff
    mito_seg = fetch_labelmap_voxels(*MITO_SEG, fetch_box, supervoxels=True, scale=mito_scale)
    assert neighborhood_mask.shape == mito_seg.shape
    mito_seg = np.where(neighborhood_mask, mito_seg, 0)

    # The mito segmentation contains small scraps and slivers which were
    # excluded from the "real" mito set; exclude them from the results too.
    mito_ids = set(pd.unique(mito_seg.ravel())) - {0}
    mito_sizes = fetch_sizes(*MITO_SEG, list(mito_ids), supervoxels=True)
    mito_sizes = mito_sizes.rename_axis('mito')
    voxel_scale = (2**mito_res_scale_diff)**3
    mito_sizes *= voxel_scale

    # Main result: the mito IDs (and sizes) that pass the size threshold.
    big_enough = (mito_sizes >= MIN_MITO_SIZE)
    mito_sizes = mito_sizes.loc[big_enough]

    # Extra info: connected components formed by the surviving mitos.
    mito_mask = mask_for_labels(mito_seg, mito_sizes.index)
    mito_box = compute_nonzero_box(mito_mask)
    mito_mask = extract_subvol(mito_mask, mito_box)
    mito_seg = extract_subvol(mito_seg, mito_box)
    mito_cc = label(mito_mask, connectivity=1)

    column_names = {'left': 'mito', 'right': 'cc', 'voxel_count': 'cc_size'}
    ct = (
        contingency_table(mito_seg, mito_cc)
        .reset_index()
        .rename(columns=column_names)
        .set_index('mito')
    )

    return pd.DataFrame(mito_sizes).merge(ct, how='left', left_index=True, right_index=True)
def _fetch_body_mask(seg_src, primary_point_s0, body, search_cfg, tbar_points_s0, logger):
    """
    Download the segmentation around a point and build a labeled mask
    for the given body.

    The segmentation is fetched at ``download_scale`` and the resulting
    mask is reduced to ``analysis_scale`` with an 'or' downsampling, so
    that any 'on' voxel survives.

    The result is NOT a boolean volume.  It is a uint8 volume in which
    voxels of the body itself are labeled 2 and voxels that were only
    added by the gap-filling (dilation) step are labeled 1.  Keeping the
    two apart is handy for debugging, and it is also used to restrict
    which voxels are preserved when filtering mitos, if no valid_mitos
    are provided to that function.

    Args:
        seg_src:
            Segmentation source; must be a VolumeService.
        primary_point_s0:
            The point of interest, in scale-0 coordinates.
        body:
            Label of the body to extract.
        search_cfg:
            Unpacked as (radius_s0, download_scale, analysis_scale,
            dilation_radius_s0, dilation_exclusion_buffer_s0, is_final).
        tbar_points_s0:
            Array of tbar coordinates (one per row) in scale-0 coordinates.
        logger:
            Logger handed to the Timer.

    Returns:
        (body_mask, mask_box), where mask_box is the box of body_mask in
        analysis-scale coordinates.
    """
    (radius_s0, download_scale, analysis_scale, dilation_radius_s0,
     dilation_exclusion_buffer_s0, _is_final) = search_cfg

    scale_diff = analysis_scale - download_scale

    with Timer("Fetching body segmentation", logger):
        assert _have_flyemflows and isinstance(seg_src, VolumeService)

        download_factor = 2**download_scale
        center = np.asarray(primary_point_s0) // download_factor
        reach = radius_s0 // download_factor

        # Align to 64-px for consistency with the dvid case,
        # and compatibility with the code below.
        fetch_box = round_box([center - reach, center + reach + 1], 64 * (2**scale_diff), 'out')
        center_local = center - fetch_box[0]
        seg = seg_src.get_subvolume(fetch_box, download_scale)

    # Binary mask of the body; the segmentation itself is no longer needed.
    raw_mask = (seg == body).view(np.uint8)
    del seg

    # Move everything to the analysis scale.  The mask is reduced
    # conservatively: 'on' voxels are kept no matter what.
    shrink = 2**scale_diff
    fetch_box //= shrink
    center //= shrink
    center_local //= shrink
    raw_mask = downsample_mask(raw_mask, shrink, 'or')

    # Downsampling effects in the original data may have pushed the
    # main tbar off of its body, so force it onto the mask.
    raw_mask[tuple(center_local)] = True

    # Every other tbar inside the mask box gets the same treatment.
    tbars = tbar_points_s0 // (2**analysis_scale)
    at_or_after_start = (tbars >= fetch_box[0]).all(axis=1)
    before_stop = (tbars < fetch_box[1]).all(axis=1)
    tbars = tbars[at_or_after_start & before_stop]
    tbars_local = tbars - fetch_box[0]
    raw_mask[tuple(tbars_local.transpose())] = True

    assert raw_mask.dtype == bool
    raw_mask = vigra.taggedView(raw_mask.view(np.uint8), 'zyx')

    dilation_buffer = dilation_exclusion_buffer_s0 // (2**analysis_scale)
    dilation_box = np.array((fetch_box[0] + dilation_buffer,
                             fetch_box[1] - dilation_buffer))

    # Crop the mask down to its non-zero extent.
    tight_box = compute_nonzero_box(raw_mask)
    raw_mask = raw_mask[box_to_slicing(*tight_box)]
    mask_box = tight_box + fetch_box[0]
    center_local -= tight_box[0]

    # Fill gaps
    repaired_mask = _fill_gaps(raw_mask, mask_box, analysis_scale, dilation_radius_s0, dilation_box)

    # Voxel labels:
    #   1: filled gaps
    #   2: filled gaps AND raw
    assert raw_mask.dtype == repaired_mask.dtype == np.uint8
    body_mask = np.where(repaired_mask, raw_mask + repaired_mask, 0)

    # The primary point must have ended up inside the final mask.
    assert body_mask[tuple(center - mask_box[0])]
    return body_mask, mask_box
def _crop_body_mask_and_mito_seg(body_mask, mito_seg, mask_box, search_cfg, batch_tbars, primary_point, logger):
    """
    To reduce the size of the analysis volumes during distance computation
    (the most expensive step), we pre-filter out components of the body
    mask that don't actually contain both points of interest and mito.

    If those segments don't even touch the volume edges, then any points
    on those segments can be safely marked 'done' if this is the final
    search config.

    Args:
        body_mask:
            Label volume; non-zero voxels belong to the (repaired) body.
        mito_seg:
            Mito label volume with the same shape as body_mask.
        mask_box:
            Box (start, stop) of the two volumes, in analysis-scale
            coordinates.
        search_cfg:
            Search configuration.  The attributes analysis_scale, is_final,
            radius_s0, dilation_radius_s0 and dilation_exclusion_buffer_s0
            are read here.
        batch_tbars:
            DataFrame with 'z', 'y', 'x' columns.  The coordinates are
            divided by 2**analysis_scale before use.
        primary_point:
            Center of the original search box.
            NOTE(review): it is combined with an analysis-scale radius, so
            it is presumably in analysis-scale coordinates -- confirm
            against the caller.
        logger:
            Logger used for timing and summary messages.

    Returns:
        (body_mask, mito_seg, mask_box, hopeless_point_ids)

        The volumes are restricted to the kept components and cropped to
        their bounding box, and mask_box is updated to match.  If no
        component was kept, the first two items are None and the third is
        the (all-zero) local bounding box.

        hopeless_point_ids is an empty list unless this is the final search
        config and some point component lacks mito, in which case it is the
        index of the batch_tbars rows on hopeless components.
    """
    with Timer("Filtering components and cropping", logger):
        body_cc = labelMultiArrayWithBackground((body_mask != 0).view(np.uint8))

        # Keep only components which contain both mito and points
        tbar_points = batch_tbars[[*'zyx']].values // (2 ** search_cfg.analysis_scale)
        is_in_box = ((tbar_points >= mask_box[0]).all(axis=1)
                     & (tbar_points < mask_box[1]).all(axis=1))
        tbar_points = tbar_points[is_in_box]
        pts_local = tbar_points - mask_box[0]

        point_cc_df = batch_tbars.iloc[is_in_box][[*'zyx']].copy()
        point_cc_df['cc'] = body_cc[tuple(np.transpose(pts_local))]
        point_ccs = set(point_cc_df['cc'])
        mito_ccs = set(pd.unique(body_cc[mito_seg != 0]))

        keep_ccs = point_ccs & mito_ccs
        keep_mask = mask_for_labels(body_cc, keep_ccs)
        body_mask = np.where(keep_mask, body_mask, 0)
        mito_seg = np.where(keep_mask, mito_seg, 0)
        logger.info(f"Dropped {body_cc.max() - len(keep_ccs)} components, kept {len(keep_ccs)}")

        # Also determine the set of points which should be marked as hopeless,
        # due to a lack of mitos on their components.
        # Hopeless points are those which reside on hopeless components.
        # Hopeless components are any components that fall within the dilation
        # region and still ended up without mitos.
        hopeless_point_ids = []
        if search_cfg.is_final and (point_ccs - mito_ccs):
            with Timer("Identifying hopeless points", logger):
                # Calculate the region that is subject to repairs (dilation),
                # in local coordinates.
                dr = search_cfg.dilation_radius_s0 // (2 ** search_cfg.analysis_scale)
                buf = search_cfg.dilation_exclusion_buffer_s0 // (2 ** search_cfg.analysis_scale)
                buf += max(1, dr)

                R = search_cfg.radius_s0 // (2 ** search_cfg.analysis_scale)
                orig_box = np.array([primary_point - R, primary_point + R + 1])

                # Move the start corner inward by +buf and the stop corner by -buf.
                inner_box = orig_box + np.array([buf, -buf])[:, None]
                inner_box = box_intersection(mask_box, inner_box)
                inner_box = inner_box - mask_box[0]

                inner_vol = body_cc[box_to_slicing(*inner_box)]
                inner_ccs = set(pd.unique(inner_vol.ravel())) - {0}

                # Overwrite body_cc, we don't need it for anything else after this.
                body_cc[box_to_slicing(*inner_box)] = 0

                # Whatever labels remain are components that extend beyond
                # the inner box.  Deduplicate with pd.unique() first so that
                # we don't iterate over every voxel in Python.
                outer_ccs = set(pd.unique(body_cc.ravel()))

                hopeless_ccs = (inner_ccs - outer_ccs) - mito_ccs  # noqa
                hopeless_point_ids = point_cc_df.query('cc in @hopeless_ccs').index

        # Shrink the volume bounding box to encompass only the
        # non-zero portion of the filtered body mask.
        nz_box = compute_nonzero_box(keep_mask)
        if not nz_box.any():
            # Nothing was kept.
            return None, None, nz_box, hopeless_point_ids

        body_mask = body_mask[box_to_slicing(*nz_box)]
        mito_seg = mito_seg[box_to_slicing(*nz_box)]
        mask_box = mask_box[0] + nz_box

    return body_mask, mito_seg, mask_box, hopeless_point_ids
def overwrite_box(box, lowres_mask):
    """
    Zero out the masked voxels of one brick of segmentation and post the
    blocks that changed.

    The names block_width, scale, input_service and output_service are
    not parameters; they are taken from the enclosing scope.

    Args:
        box:
            Box (start, stop) of the brick.  The start corner must be
            aligned to block_width.
        lowres_mask:
            Boolean mask of the voxels to erase; at least one voxel must
            be set.  It is resampled by a factor of 2**(5 - scale) to match
            the segmentation.
            NOTE(review): this implies the mask is stored at scale 5 --
            confirm against the caller.

    Returns:
        None if no voxels changed, or if scale != 0.  Otherwise, the
        result of block_stats_for_volume() for the erased voxels.
    """
    # Use the builtin bool here: np.bool does not exist in NumPy 1.24-1.26.
    assert lowres_mask.dtype == bool
    assert not (box[0] % block_width).any()
    assert lowres_mask.any(), \
        "This function is supposed to be called on bricks that actually need masking"

    # Crop box and mask to only include the extent of the masked voxels
    nonzero_mask_box = compute_nonzero_box(lowres_mask)
    nonzero_mask_box = round_box(nonzero_mask_box, (block_width * 2**scale) // 2**5)
    lowres_mask = extract_subvol(lowres_mask, nonzero_mask_box)

    # For scale > 5 the factor 2**(5 - scale) is fractional,
    # hence the cast back to an integer box.
    box = box[0] + (nonzero_mask_box * 2**(5 - scale))
    box = box.astype(np.int32)

    if scale <= 5:
        mask = upsample(lowres_mask, 2**(5 - scale))
    else:
        # Downsample, but favor UNmasked voxels
        mask = ~view_as_blocks(~lowres_mask, 3 * (2**(scale - 5), )).any(axis=(3, 4, 5))

    old_seg = input_service.get_subvolume(box, scale)

    assert mask.dtype == bool
    new_seg = old_seg.copy()
    new_seg[mask] = 0

    if (new_seg == old_seg).all():
        # It's possible that there are no changed voxels, but only
        # at high scales where the masked voxels were downsampled away.
        #
        # So if the original downscale pyramids are perfect,
        # then the following assumption ought to hold.
        #
        # But I'm commenting it out in case the DVID pyramid at scale 5
        # isn't pixel-perfect in some places.
        #
        # assert scale > 5
        return None

    def post_changed_blocks(old_seg, new_seg):
        # If we post the whole volume, we'll be overwriting blocks that haven't changed,
        # wasting space in DVID (for duplicate blocks stored in the child uuid).
        # Instead, we need to only post the blocks that have changed.

        # So, can't just do this:
        # output_service.write_subvolume(new_seg, box[0], scale)

        seg_diff = (old_seg != new_seg)
        block_diff = view_as_blocks(seg_diff, 3 * (block_width, ))

        # Grid indices of the blocks that contain at least one changed voxel
        changed_block_map = block_diff.any(axis=(3, 4, 5)).nonzero()
        changed_block_corners = box[0] + np.transpose(changed_block_map) * block_width

        changed_blocks = view_as_blocks(new_seg, 3 * (block_width, ))[changed_block_map]
        encoded_blocks = encode_labelarray_blocks(changed_block_corners, changed_blocks)

        mgr = output_service.resource_manager_client
        with mgr.access_context(output_service.server, True, 1, changed_blocks.nbytes):
            post_labelmap_blocks(*output_service.instance_triple, None, encoded_blocks, scale,
                                 downres=False, noindexing=True, throttle=False, is_raw=True)

    assert not (box % block_width).any(), \
        "Should not write partial blocks"

    post_changed_blocks(old_seg, new_seg)
    del new_seg

    if scale != 0:
        # Don't collect statistics for higher scales
        return None

    # Keep only the voxels that were just erased, for the statistics.
    erased_seg = old_seg.copy()
    erased_seg[~mask] = 0

    block_shape = 3 * (input_service.block_width, )
    erased_stats_df = block_stats_for_volume(block_shape, erased_seg, box)
    return erased_stats_df