def mitos_in_neighborhood(mito_roi_source, neighborhood_origin_xyz,
                          neighborhood_id, mito_res_scale_diff):
    """
    Determine how many non-trivial mito objects overlap with the given "neighborhood object",
    and return a table of their IDs and sizes.

    1. Download the neighborhood mask for the given neighborhood_id.
    2. Erode the neighborhood mask by 1 px (see note in the comment above).
    3. Fetch the mito segmentation for the voxels within the neighborhood.
    4. Fetch (from dvid) the sizes of each mito object.
    5. Filter out the mitos that are smaller than the minimum size that is
       actually used in our published mito analyses.
    6. Just for additional info, determine how many connected components
       are formed by the mito objects.
    7. Return the mito IDs, sizes, and CC info as a DataFrame.
    """
    # The neighborhood segmentation source
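    # The source string is expected to look like 'http://server/uuid/instance',
    # possibly with a transport prefix, e.g. 'dvid://http://server/uuid/instance'.
    # Taking the last two '://'-separated fields discards any such prefix.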
    protocol, url = mito_roi_source.split('://')[-2:]
    server, uuid, instance = url.split('/')
    server = f'{protocol}://{server}'

    origin_zyx = np.array(neighborhood_origin_xyz[::-1])
    box = [origin_zyx - RADIUS, 1 + origin_zyx + RADIUS]

    # Align box to the analysis scale before scaling it.
    box = round_box(box, (2**ANALYSIS_SCALE))

    # Scale box
    box //= (2**ANALYSIS_SCALE)

    neighborhood_seg = fetch_labelmap_voxels(server,
                                             uuid,
                                             instance,
                                             box,
                                             scale=ANALYSIS_SCALE)
    neighborhood_mask = (neighborhood_seg == neighborhood_id)

    # This is equivalent to a 1-px erosion
    # See note above for why we do this.
    neighborhood_mask ^= binary_edge_mask(neighborhood_mask, 'inner')

    mito_seg = fetch_labelmap_voxels(*MITO_SEG,
                                     box,
                                     supervoxels=True,
                                     scale=ANALYSIS_SCALE - mito_res_scale_diff)
    assert neighborhood_mask.shape == mito_seg.shape
    mito_seg = np.where(neighborhood_mask, mito_seg, 0)

    # The mito segmentation includes little scraps and slivers
    # that were filtered out of the "real" mito set.
    # Filter those scraps out of our results here.
    mito_ids = set(pd.unique(mito_seg.ravel())) - {0}
    mito_sizes = fetch_sizes(*MITO_SEG, [*mito_ids], supervoxels=True)
    mito_sizes = mito_sizes.rename_axis('mito')
    # fetch_sizes() returns voxel counts at the mito segmentation's native
    # resolution; convert them to voxel counts at the primary segmentation's
    # resolution so they're comparable with MIN_MITO_SIZE.
    mito_sizes *= (2**mito_res_scale_diff)**3

    # This is our main result: mito IDs (and their sizes)
    mito_sizes = mito_sizes.loc[mito_sizes >= MIN_MITO_SIZE]

    # Just for extra info, group the mitos we found into connected components.
    mito_mask = mask_for_labels(mito_seg, mito_sizes.index)
    mito_box = compute_nonzero_box(mito_mask)
    mito_mask = extract_subvol(mito_mask, mito_box)
    mito_seg = extract_subvol(mito_seg, mito_box)
    mito_cc = label(mito_mask, connectivity=1)
    ct = contingency_table(mito_seg, mito_cc).reset_index()
    ct = ct.rename(columns={
        'left': 'mito',
        'right': 'cc',
        'voxel_count': 'cc_size'
    })
    ct = ct.set_index('mito')
    mito_sizes = pd.DataFrame(mito_sizes).merge(ct, how='left',
                                                left_index=True,
                                                right_index=True)
    return mito_sizes
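

# Example usage (a sketch: the server, uuid, instance, and coordinates below
# are hypothetical, and RADIUS / ANALYSIS_SCALE / MIN_MITO_SIZE / MITO_SEG are
# module-level globals assumed to be configured elsewhere):
#
#   mito_table = mitos_in_neighborhood(
#       'dvid://http://emdata.example.org:8000/abc123/neighborhood-masks',
#       neighborhood_origin_xyz=(10240, 20480, 30720),
#       neighborhood_id=12345,
#       mito_res_scale_diff=1)
#
# The result is a DataFrame indexed by mito ID, giving each mito's size and
# its connected-component assignment.
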
def _fetch_body_mask(seg_src, primary_point_s0, body, search_cfg, tbar_points_s0, logger):
    """
    Fetch a mask for the given body around the given point, using the radius
    and scales specified in the given search config.

    The mask is downloaded at download_scale and then rescaled to
    analysis_scale using continuity-preserving downsampling.

    The returned mask is NOT a binary (boolean) volume. Instead, a uint8 volume is returned,
    with labels 1 and 2, indicating which portion of the mask belongs to the body (2) and
    which portion was added due to dilation (1).

    Providing a non-binary result is convenient for debugging.  It is also used to
    restrict the voxels that are preserved when filtering mitos, if no valid_mitos
    are provided to that function.
    """
    (radius_s0, download_scale, analysis_scale,
        dilation_radius_s0, dilation_exclusion_buffer_s0, _is_final) = search_cfg

    scale_diff = analysis_scale - download_scale
    with Timer("Fetching body segmentation", logger):
        assert _have_flyemflows and isinstance(seg_src, VolumeService)
        p = np.asarray(primary_point_s0) // (2**download_scale)
        R = radius_s0 // (2**download_scale)
        seg_box = [p-R, p+R+1]

        # Align to 64-px for consistency with the dvid case,
        # and compatibility with the code below.
        seg_box = round_box(seg_box, 64 * (2**scale_diff), 'out')
        p_local = (p - seg_box[0])
        seg = seg_src.get_subvolume(seg_box, download_scale)

    # Extract mask (keep it boolean for now; it's converted to uint8 below)
    raw_mask = (seg == body)
    del seg

    # Downsample mask conservatively, i.e. keeping 'on' pixels no matter what
    seg_box //= (2**scale_diff)
    p //= (2**scale_diff)
    p_local //= (2**scale_diff)
    raw_mask = downsample_mask(raw_mask, 2**scale_diff, 'or')

    # Due to downsampling effects in the original data, it's possible
    # that the main tbar fell off its body in the downsampled image.
    # Make sure it's part of the mask
    raw_mask[(*p_local,)] = True

    # The same is true for all of the other tbars that fall within the mask box.
    # Fix them all.
    tbar_points = tbar_points_s0 // (2**analysis_scale)
    in_box = (tbar_points >= seg_box[0]).all(axis=1) & (tbar_points < seg_box[1]).all(axis=1)
    tbar_points = tbar_points[in_box]
    local_tbar_points = tbar_points - seg_box[0]
    raw_mask[(*local_tbar_points.transpose(),)] = True

    assert raw_mask.dtype == bool
    raw_mask = vigra.taggedView(raw_mask.view(np.uint8), 'zyx')

    dilation_buffer = dilation_exclusion_buffer_s0 // (2**analysis_scale)
    dilation_box = np.array((seg_box[0] + dilation_buffer, seg_box[1] - dilation_buffer))

    # Shrink to fit the data.
    local_box = compute_nonzero_box(raw_mask)
    raw_mask = raw_mask[box_to_slicing(*local_box)]
    mask_box = local_box + seg_box[0]
    p_local -= local_box[0]

    # Fill gaps
    repaired_mask = _fill_gaps(raw_mask, mask_box, analysis_scale, dilation_radius_s0, dilation_box)

    # Label the voxels:
    #   1: voxels added by gap-filling (dilation) only
    #   2: voxels present in both the raw mask AND the repaired mask
    assert raw_mask.dtype == repaired_mask.dtype == np.uint8
    body_mask = np.where(repaired_mask, raw_mask + repaired_mask, 0)

    assert body_mask[(*p - mask_box[0],)]
    return body_mask, mask_box
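

# Note: _fetch_body_mask() above unpacks search_cfg positionally, while
# _crop_body_mask_and_mito_seg() below accesses the same fields by attribute,
# so a namedtuple along these lines would satisfy both (a sketch; the real
# definition lives elsewhere, and the field names are inferred from usage):
#
#   from collections import namedtuple
#   SearchConfig = namedtuple('SearchConfig',
#                             'radius_s0 download_scale analysis_scale '
#                             'dilation_radius_s0 dilation_exclusion_buffer_s0 '
#                             'is_final')
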
def _crop_body_mask_and_mito_seg(body_mask, mito_seg, mask_box, search_cfg, batch_tbars, primary_point, logger):
    """
    To reduce the size of the analysis volumes during distance computation
    (the most expensive step), we pre-filter out components of the body mask
    that don't actually contain both points of interest and mitos.

    If those segments don't even touch the volume edges,
    then any points on those segments can be safely marked 'done'
    if this is the final search config.
    """
    with Timer("Filtering components and cropping", logger):
        body_cc = labelMultiArrayWithBackground((body_mask != 0).view(np.uint8))

        # Keep only components which contain both mito and points
        tbar_points = batch_tbars[[*'zyx']].values // (2 ** search_cfg.analysis_scale)
        is_in_box = (tbar_points >= mask_box[0]).all(axis=1) & (tbar_points < mask_box[1]).all(axis=1)
        tbar_points = tbar_points[is_in_box]
        pts_local = tbar_points - mask_box[0]

        point_cc_df = batch_tbars.iloc[is_in_box][[*'zyx']].copy()
        point_cc_df['cc'] = body_cc[tuple(np.transpose(pts_local))]

        point_ccs = set(point_cc_df['cc'])
        mito_ccs = set(pd.unique(body_cc[mito_seg != 0]))
        keep_ccs = point_ccs & mito_ccs
        keep_mask = mask_for_labels(body_cc, keep_ccs)

        body_mask = np.where(keep_mask, body_mask, 0)
        mito_seg = np.where(keep_mask, mito_seg, 0)
        logger.info(f"Dropped {body_cc.max() - len(keep_ccs)} components, kept {len(keep_ccs)}")

        # Also determine the set of points which should be marked as hopeless,
        # due to a lack of mitos on their components.
        # Hopeless points are those which reside on hopeless components.
        # Hopeless components are any components that fall within the dilation
        # region and still ended up without mitos.
        hopeless_point_ids = []
        if search_cfg.is_final and (point_ccs - mito_ccs):
            with Timer("Identifying hopeless points", logger):
                # Calculate the region that is subject to repairs (dilation),
                # in local coordinates.
                dr = search_cfg.dilation_radius_s0 // (2 ** search_cfg.analysis_scale)
                buf = search_cfg.dilation_exclusion_buffer_s0 // (2 ** search_cfg.analysis_scale)
                buf += max(1, dr)
                R = search_cfg.radius_s0 // (2 ** search_cfg.analysis_scale)
                orig_box = np.array([primary_point - R, primary_point + R + 1])
                inner_box = orig_box + np.array([buf, -buf])[:, None]
                inner_box = box_intersection(mask_box, inner_box)
                inner_box = inner_box - mask_box[0]
                inner_vol = body_cc[box_to_slicing(*inner_box)]
                inner_ccs = set(pd.unique(inner_vol.ravel())) - {0}

                # Overwrite body_cc, we don't need it for anything else after this.
                body_cc[box_to_slicing(*inner_box)] = 0

                outer_ccs = set(body_cc.ravel())
                hopeless_ccs = (inner_ccs - outer_ccs) - mito_ccs  # noqa
                hopeless_point_ids = point_cc_df.query('cc in @hopeless_ccs').index

        # Shrink the volume bounding box to encompass only the
        # non-zero portion of the filtered body mask.
        nz_box = compute_nonzero_box(keep_mask)
        if not nz_box.any():
            return None, None, nz_box, hopeless_point_ids

        body_mask = body_mask[box_to_slicing(*nz_box)]
        mito_seg = mito_seg[box_to_slicing(*nz_box)]
        mask_box = mask_box[0] + nz_box

        return body_mask, mito_seg, mask_box, hopeless_point_ids
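
        # Note: overwrite_box() below is a closure excerpted from a larger
        # enclosing function that is not shown here; block_width, scale,
        # input_service, and output_service all come from that enclosing scope.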
        def overwrite_box(box, lowres_mask):
            assert lowres_mask.dtype == bool
            assert not (box[0] % block_width).any()
            assert lowres_mask.any(), \
                "This function is supposed to be called on bricks that actually need masking"

            # Crop box and mask to only include the extent of the masked voxels
            nonzero_mask_box = compute_nonzero_box(lowres_mask)
            nonzero_mask_box = round_box(nonzero_mask_box,
                                         (block_width * 2**scale) // 2**5)
            lowres_mask = extract_subvol(lowres_mask, nonzero_mask_box)

            box = box[0] + (nonzero_mask_box * 2**(5 - scale))
            box = box.astype(np.int32)

            if scale <= 5:
                mask = upsample(lowres_mask, 2**(5 - scale))
            else:
                # Downsample, but favor UNmasked voxels
                mask = ~view_as_blocks(~lowres_mask, 3*(2**(scale - 5),)).any(axis=(3, 4, 5))

            old_seg = input_service.get_subvolume(box, scale)

            assert mask.dtype == bool
            new_seg = old_seg.copy()
            new_seg[mask] = 0

            if (new_seg == old_seg).all():
                # It's possible that there are no changed voxels, but only
                # at high scales where the masked voxels were downsampled away.
                #
                # So if the original downscale pyramids are perfect,
                # then the following assumption ought to hold.
                #
                # But I'm commenting it out in case the DVID pyramid at scale 5
                # isn't pixel-perfect in some places.
                #
                # assert scale > 5

                return None

            def post_changed_blocks(old_seg, new_seg):
                # If we post the whole volume, we'll be overwriting blocks that haven't changed,
                # wasting space in DVID (for duplicate blocks stored in the child uuid).
                # Instead, we need to only post the blocks that have changed.

                # So, can't just do this:
                # output_service.write_subvolume(new_seg, box[0], scale)

                seg_diff = (old_seg != new_seg)
                block_diff = view_as_blocks(seg_diff, 3 * (block_width, ))

                changed_block_map = block_diff.any(axis=(3, 4, 5)).nonzero()
                changed_block_corners = box[0] + np.transpose(changed_block_map) * block_width

                changed_blocks = view_as_blocks(new_seg, 3*(block_width,))[changed_block_map]
                encoded_blocks = encode_labelarray_blocks(changed_block_corners, changed_blocks)

                mgr = output_service.resource_manager_client
                with mgr.access_context(output_service.server, True, 1,
                                        changed_blocks.nbytes):
                    post_labelmap_blocks(*output_service.instance_triple,
                                         None,
                                         encoded_blocks,
                                         scale,
                                         downres=False,
                                         noindexing=True,
                                         throttle=False,
                                         is_raw=True)

            assert not (box % block_width).any(), \
                "Should not write partial blocks"

            post_changed_blocks(old_seg, new_seg)
            del new_seg

            if scale != 0:
                # Don't collect statistics for higher scales
                return None

            erased_seg = old_seg.copy()
            erased_seg[~mask] = 0

            block_shape = 3 * (input_service.block_width, )
            erased_stats_df = block_stats_for_volume(block_shape, erased_seg, box)
            return erased_stats_df
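
        # A minimal sketch of how the per-brick stats returned above might be
        # aggregated downstream (assuming block_stats_for_volume() emits one row
        # per (block, segment) pair, with columns including 'segment_id' and
        # 'count'):
        #
        #   full_stats_df = pd.concat(erased_stats_dfs, ignore_index=True)
        #   erased_voxel_counts = full_stats_df.groupby('segment_id')['count'].sum()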