def get_sizes(label_ids):
    try:
        sizes_pd = fetch_sizes(*master_seg, label_ids, supervoxels=False)
    except HTTPError:
        # On failure, return a sentinel size of -1 for every requested body.
        s_empty_pd = pd.Series(-1, index=label_ids, dtype=int, name='size')
        s_empty_pd.index.name = 'body'
        return s_empty_pd
    else:
        return sizes_pd
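A minimal usage sketch (the body IDs below are placeholders, and `master_seg` is assumed to be a `(server, uuid, instance)` tuple defined at module level, as the call above implies):

sizes = get_sizes([1071121755, 2382781912])   # hypothetical body IDs

# Failed requests come back as -1, so they are easy to separate from real sizes.
failed_bodies = sizes[sizes == -1].index
valid_sizes = sizes[sizes >= 0]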
Example 2
def post_empty_meshes(server,
                      uuid,
                      instance='segmentation_sv_meshes',
                      svs=[],
                      permit_large=False,
                      check_sizes=True):
    """
    Given a list of supervoxel ids (presumably for SMALL supervoxels),
    post an empty .drc file to the tarsupervoxels instance for each one.
    
    (By convention, we do not generally store meshes for very tiny supervoxels.
    Instead, we store empty mesh files (i.e. 0 bytes) in their place, and
    our proofreading tools understand this convention.)
    
    Since this function is generally supposed to be used with only small supervoxels,
    it will refuse to write empty files for any supervoxels larger than 100 voxels,
    unless you pass permit_large=True.
    """
    import tarfile
    from io import BytesIO
    from tqdm import tqdm
    from neuclease.dvid import fetch_instance_info, fetch_sizes, post_load

    # Determine segmentation instance
    info = fetch_instance_info(server, uuid, instance)
    segmentation_instance = info["Base"]["Syncs"][0]

    sizes = None
    if check_sizes:
        sizes = fetch_sizes(server,
                            uuid,
                            segmentation_instance,
                            svs,
                            supervoxels=True)
        if (sizes > 100).any():
            msg = "Some of those supervoxels are large ({sizes.max()} voxels)."
            if permit_large:
                logger.warning(msg)
            else:
                msg = f"Error: {msg} Pass permit_large=True if you really mean it."
                raise RuntimeError(msg)

    bio = BytesIO()
    # Close the tar (via the 'with' block) so the end-of-archive blocks are
    # written before we read the buffer below.
    with tarfile.TarFile('empty-svs.tar', 'w', bio) as tf:
        for sv in tqdm(svs):
            tf.addfile(tarfile.TarInfo(f'{sv}.drc'), BytesIO())

    post_load(server, uuid, instance, bio.getvalue())
    return sizes
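A hedged usage sketch for post_empty_meshes; the server, uuid, and supervoxel IDs below are placeholders, not values from the original code:

# Hypothetical DVID endpoint and a couple of tiny supervoxels.
server = 'http://example-dvid:8000'
uuid = 'abc123'
tiny_svs = [1234567, 1234568]

# Posts one empty .drc entry per supervoxel and returns the fetched sizes
# (or None if check_sizes=False).
sizes = post_empty_meshes(server, uuid, 'segmentation_sv_meshes', tiny_svs)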
def mitos_in_neighborhood(mito_roi_source, neighborhood_origin_xyz,
                          neighborhood_id, mito_res_scale_diff):
    """
    Determine how many non-trivial mito objects overlap with the given "neighborhood object",
    and return a table of their IDs and sizes.

    1. Download the neighborhood mask for the given neighborhood_id.
    2. Erode the neighborhood mask by 1 px (see note in the comment above).
    3. Fetch the mito segmentation for the voxels within the neighborhood.
    4. Fetch (from dvid) the sizes of each mito object.
    5. Filter out the mitos that are smaller than the minimum size that is
       actually used in our published mito analyses.
    6. Just for additional info, determine how many connected components
       are formed by the mito objects.
    7. Return the mito IDs, sizes, and CC info as a DataFrame.
    """
    # The neighborhood segmentation source
    protocol, url = mito_roi_source.split('://')[-2:]
    server, uuid, instance = url.split('/')
    server = f'{protocol}://{server}'

    origin_zyx = np.array(neighborhood_origin_xyz[::-1])
    box = [origin_zyx - RADIUS, 1 + origin_zyx + RADIUS]

    # Align box to the analysis scale before scaling it.
    box = round_box(box, (2**ANALYSIS_SCALE))

    # Scale box
    box //= (2**ANALYSIS_SCALE)

    neighborhood_seg = fetch_labelmap_voxels(server,
                                             uuid,
                                             instance,
                                             box,
                                             scale=ANALYSIS_SCALE)
    neighborhood_mask = (neighborhood_seg == neighborhood_id)

    # This is equivalent to a 1-px erosion
    # See note above for why we do this.
    neighborhood_mask ^= binary_edge_mask(neighborhood_mask, 'inner')

    mito_seg = fetch_labelmap_voxels(*MITO_SEG,
                                     box,
                                     supervoxels=True,
                                     scale=ANALYSIS_SCALE -
                                     mito_res_scale_diff)
    assert neighborhood_mask.shape == mito_seg.shape
    mito_seg = np.where(neighborhood_mask, mito_seg, 0)

    # The mito segmentation includes little scraps and slivers
    # that were filtered out of the "real" mito set.
    # Filter those scraps out of our results here.
    mito_ids = set(pd.unique(mito_seg.ravel())) - {0}
    mito_sizes = fetch_sizes(*MITO_SEG, [*mito_ids], supervoxels=True)
    mito_sizes = mito_sizes.rename_axis('mito')
    mito_sizes *= (2**mito_res_scale_diff)**3

    # This is our main result: mito IDs (and their sizes)
    mito_sizes = mito_sizes.loc[mito_sizes >= MIN_MITO_SIZE]

    # Just for extra info, group the mitos we found into connected components.
    mito_mask = mask_for_labels(mito_seg, mito_sizes.index)
    mito_box = compute_nonzero_box(mito_mask)
    mito_mask = extract_subvol(mito_mask, mito_box)
    mito_seg = extract_subvol(mito_seg, mito_box)
    mito_cc = label(mito_mask, connectivity=1)
    ct = contingency_table(mito_seg, mito_cc).reset_index()
    ct = ct.rename(columns={
        'left': 'mito',
        'right': 'cc',
        'voxel_count': 'cc_size'
    })
    ct = ct.set_index('mito')
    mito_sizes = pd.DataFrame(mito_sizes).merge(ct,
                                                'left',
                                                left_index=True,
                                                right_index=True)
    return mito_sizes
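A usage sketch, assuming a mito_roi_source of the form 'protocol://server/uuid/instance' as parsed above; all values below are placeholders:

# Placeholder source string and neighborhood parameters.
mito_roi_source = 'http://example-dvid:8000/abc123/neighborhood-masks'
origin_xyz = (15000, 20000, 18000)

mito_table = mitos_in_neighborhood(mito_roi_source, origin_xyz,
                                   neighborhood_id=12345,
                                   mito_res_scale_diff=1)
print(mito_table.head())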
Example 4
def identify_mito_bodies(body_seg,
                         mito_binary,
                         box,
                         scale,
                         halo,
                         body_seg_dvid_src=None,
                         viewer=None,
                         res0=8,
                         resource_mgr_client=None):
    # Identify segments that are mostly mito
    ct = contingency_table(body_seg,
                           mito_binary).reset_index().rename(columns={
                               'left': 'body',
                               'right': 'is_mito'
                           })
    ct = ct.pivot(index='body', columns='is_mito',
                  values='voxel_count').fillna(0).rename(columns={
                      0: 'non_mito',
                      1: 'mito'
                  })
    if 'mito' not in ct or 'non_mito' not in ct:
        # Nothing to do if there aren't any mito voxels (or if everything is mito)
        return None, None, None
    ct[['mito', 'non_mito']] *= ((2**scale)**3)

    ct['body_size_local'] = ct.eval('mito+non_mito')
    ct['mito_frac_local'] = ct.eval('mito/body_size_local')
    ct = ct.sort_values('mito_frac_local', ascending=False)

    # Also compute the halo vs. non-halo sizes of every body.
    central_box = (box - box[0]) + [[halo, halo, halo], [-halo, -halo, -halo]]
    central_body_seg = body_seg[box_to_slicing(*central_box)]
    central_sizes = (pd.Series(central_body_seg.ravel('K'))
                       .value_counts()
                       .rename('body_size_central')
                       .rename_axis('body'))

    central_mask = np.ones(central_box[1] - central_box[0], bool)
    update_mask_layer(viewer, 'central-box', central_mask, scale,
                      central_box + box[0])

    ct = ct.merge(central_sizes, 'left', on='body').fillna(0)
    ct['halo_size'] = ct.eval('body_size_local - body_size_central')
    ct = ct.query('body != 0')

    # Immediately drop bodies that reside only in the halo
    ct = ct.query('body_size_central > 0').copy()

    # For the bodies that MIGHT pass the mito threshold (based on their local size)
    # fetch their global size, if a dvid source was provided.
    # If not, we'll just use the local size, which is less accurate but
    # faster since we've already got it.
    if body_seg_dvid_src is None:
        ct['body_size'] = ct['body_size_local']
    else:
        local_mito_bodies = ct.query(
            'mito_frac_local >= @MITO_EDGE_FRAC').index

        if resource_mgr_client is None:
            body_sizes = fetch_sizes(*body_seg_dvid_src,
                                     local_mito_bodies).rename('body_size')
        else:
            with resource_mgr_client.access_context(body_seg_dvid_src[0], True,
                                                    1, 1):
                body_sizes = fetch_sizes(*body_seg_dvid_src,
                                         local_mito_bodies).rename('body_size')

        ct = ct.merge(body_sizes, 'left', on='body')

    # Due to downsampling effects, bodies can be larger at scale-1 than at scale-0, especially for tiny volumes.
    ct['mito_frac_global_vol'] = np.minimum(ct.eval('mito/body_size'), 1.0)

    # Calculate the proportion of mito edge pixels
    body_edges = np.where(edge_mask(body_seg, 'both'), body_seg, np.uint64(0))
    edge_ct = contingency_table(
        body_edges, mito_binary).reset_index().rename(columns={
            'left': 'body',
            'right': 'is_mito'
        })
    edge_ct = edge_ct.pivot(index='body',
                            columns='is_mito',
                            values='voxel_count').fillna(0).rename(columns={
                                0: 'non_mito',
                                1: 'mito'
                            })

    # Surface area scales with square of resolution, not cube
    edge_ct[['mito', 'non_mito']] *= ((2**scale)**2)

    edge_ct['body_size_local'] = edge_ct.eval('mito+non_mito')
    edge_ct['mito_frac_local'] = edge_ct.eval('mito/body_size_local')
    edge_ct = edge_ct.sort_values('mito_frac_local', ascending=False)
    edge_ct = edge_ct.query('body != 0')

    full_ct = ct.merge(edge_ct, 'inner', on='body', suffixes=['_vol', '_edge'])
    q = ("body_size < @MAX_MITO_FRAGMENT_VOL"
         " and mito_frac_global_vol >= @MITO_VOL_FRAC"
         " and mito_frac_local_edge >= @MITO_EDGE_FRAC")
    filtered_ct = full_ct.query(q)

    mito_bodies = filtered_ct.index
    mito_bodies_mask = mask_for_labels(body_seg, mito_bodies)
    update_mask_layer(viewer, 'mito-bodies-mask', mito_bodies_mask, scale, box,
                      res0)

    if len(filtered_ct) == 0:
        return None, None, None

    return mito_bodies, mito_bodies_mask, filtered_ct.copy()
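The "mostly mito" test above hinges on the contingency-table pivot. Here is a self-contained toy version of that step in plain pandas (it does not use neuclease's contingency_table; it only mirrors the pivot/eval logic on a tiny 1-D example):

import numpy as np
import pandas as pd

# Toy "volumes": 3 voxels of body 7 overlap the mito mask, 1 does not; body 9 has no mito.
body_seg = np.array([7, 7, 7, 7, 9, 9], dtype=np.uint64)
mito_binary = np.array([1, 1, 1, 0, 0, 0], dtype=np.uint64)

ct = (pd.DataFrame({'body': body_seg, 'is_mito': mito_binary})
        .value_counts()
        .rename('voxel_count')
        .reset_index()
        .pivot(index='body', columns='is_mito', values='voxel_count')
        .fillna(0)
        .rename(columns={0: 'non_mito', 1: 'mito'}))

ct['mito_frac_local'] = ct.eval('mito / (mito + non_mito)')
print(ct)  # body 7 -> 0.75 mito, body 9 -> 0.0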
Example 5
    def execute(self):
        options = self.config["mitorepair"]
        mgr_config = self.config["resource-manager"]
        resource_mgr_client = ResourceManagerClient(mgr_config["server"],
                                                    mgr_config["port"])
        seg_service, mask_service = self.init_services(resource_mgr_client)

        labelmap_server = options["dvid-labelmap-size-src"]["server"]
        labelmap_uuid = options["dvid-labelmap-size-src"]["uuid"]
        labelmap_name = options["dvid-labelmap-size-src"]["segmentation-name"]
        if labelmap_server:
            assert labelmap_server and labelmap_uuid and labelmap_name, \
                "Invalid labelmap specification"
            body_seg_dvid_src = (labelmap_server, labelmap_uuid, labelmap_name)
        else:
            body_seg_dvid_src = None

        # Boxes are determined by the left volume/labels/roi
        chunk_shape = np.array(3 * (options["chunk-width-s0"], ))
        boxes = self.init_boxes(seg_service, options["roi"], chunk_shape)
        logger.info(f"Processing {len(boxes)} bricks in total.")

        with Timer(
                "Finding merges to repair mito fragmentation in the segmentation",
                logger):

            def process_box(central_box):
                fragment_table = mito_body_assignments_for_box(
                    seg_service,
                    mask_service,
                    central_box,
                    options["halo-width-s0"],
                    options["analysis-scale"],
                    body_seg_dvid_src,
                    resource_mgr_client=resource_mgr_client)
                return fragment_table

            # Compute block-wise, and drop empty results
            fragment_tables = db.from_sequence(
                boxes, partition_size=1).map(process_box).compute()
            fragment_tables = [
                *filter(lambda t: t is not None, fragment_tables)
            ]

        with Timer("Combining fragment tables", logger):
            combined_table = pd.concat(fragment_tables)
            with open('combined-fragment-table.pkl', 'wb') as f:
                pickle.dump(combined_table, f)

        with Timer("Selecting top merge for each mito body", logger):
            filtered_table = (combined_table[['body_size_local_vol', 'hull_body']]
                              .query('hull_body != 0')
                              .sort_values('body_size_local_vol', ascending=False)
                              .groupby('body')
                              .head(1))

        try:
            filtered_table = self.append_synapse_columns(
                filtered_table, options["neuprint"])
        except Exception as ex:
            logger.error(
                f"Was not able to append synapse data from neuprint:\n{ex}")

        if body_seg_dvid_src:
            with Timer("Fetching sizes from DVID"):
                hull_body_sizes = fetch_sizes(
                    *body_seg_dvid_src,
                    filtered_table['hull_body'].values,
                    processes=8)
                filtered_table['hull_body_size'] = hull_body_sizes.values
        else:
            # The 'body_size' column is misleading if it wasn't actually fetched from dvid.
            del filtered_table['body_size']

        with Timer("Writing unfiltered top choices", logger):
            with open('unfiltered-top-choices-table.pkl', 'wb') as f:
                pickle.dump(filtered_table, f)

        with Timer(
                "Filtering out ambiguously merged bodies and bodies with synapses",
                logger):
            # Some bodies might be identified as a "mito body" in one block and a "hull body" in another.
            # (In the hemibrain v1.1 dataset, this was the case for 0.06% of all identified mito fragments.)
            # Such cases typically occur in areas of bad segmentation where our repairs don't have much hope of helping anyway.
            # Furthermore, there's nothing that prevents the merge decisions (mito -> hull) from being cyclical in such cases.
            # We could "fix" the issue by coming up with a rule, e.g. force the smaller one to merge into the bigger one,
            # and then transitively merge until we get to a non-mito body, but that seems too complicated for such a
            # tiny fraction of cases.
            # So, just drop those rows.
            hull_bodies = filtered_table['hull_body'].unique()
            conflict_bodies = filtered_table.index.intersection(
                hull_bodies).unique()
            filtered_table = filtered_table.query(
                'body not in @conflict_bodies')

            # We don't want to change the connectome. Drop mito bodies with synapses.
            if 'pre' in filtered_table.columns:
                filtered_table = filtered_table.query('pre == 0 and post == 0')

        with Timer("Writing final results", logger):
            filtered_table.to_csv('final-fragment-table.csv',
                                  header=True,
                                  index=True)
            with open('final-fragment-table.pkl', 'wb') as f:
                pickle.dump(filtered_table, f)

            if 'pre' in filtered_table.columns:
                # Also save in the format that can be loaded by a LabelmappedVolumeService
                (filtered_table.reset_index()[['body', 'hull_body']]
                 .rename(columns={'body': 'orig', 'hull_body': 'new'})
                 .to_csv('final-remapping.csv', index=False, header=True))
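A small, self-contained illustration of the conflict-body drop described in the comments above (toy IDs only): if a body shows up both as a mito body (the index) and as someone else's hull_body, its row is removed rather than trying to resolve the cycle:

import pandas as pd

# Toy merge table: each mito body is assigned the hull body it would merge into.
toy = pd.DataFrame({'hull_body': [20, 30, 10]},
                   index=pd.Index([10, 20, 40], name='body'))

conflict_bodies = toy.index.intersection(toy['hull_body'].unique())
toy = toy.query('body not in @conflict_bodies')
print(toy)  # bodies 10 and 20 are dropped; only body 40 -> hull 10 remains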
Example 6
def audit_traced_history_for_branch(server,
                                    branch='',
                                    instance='segmentation',
                                    start_uuid=None,
                                    final_uuid=None):
    """
    Audits a labelmap kafka log and dvid history to
    check for the following spurious events:
    
        Merges:
            - A traced body should never be merged into another
              non-traced body and lose its body id
            - Two traced bodies that are merged together should be flagged; this
              might happen if a 'Leaves' body is merged into a 'Roughly traced' body
        
        Cleaves/body-splits:
            - A traced body should never be split where the bigger
              piece (or close to half the size) is split off

    Note:
        This function does not yet handle split events,
        which leads to certain cases that will be missed here.
        Furthermore, the body sizes in the results are not guaranteed to be accurate.
    """
    if start_uuid is not None:
        start_uuid = expand_uuid(server, start_uuid)

    if final_uuid is not None:
        final_uuid = expand_uuid(server, final_uuid)

    branch_uuids = find_branch_nodes(server, final_uuid, branch)
    leaf_uuid = branch_uuids[-1]

    if start_uuid is None:
        # Start with the second uuid by default
        start_uuid = branch_uuids[1]
    else:
        assert start_uuid != branch_uuids[0], \
            "Can't start from the root uuid, since the size/status information before that is unknown."

    if final_uuid is None:
        final_uuid = branch_uuids[-1]

    start_uuid_index = branch_uuids.index(start_uuid)
    final_uuid_index = branch_uuids.index(final_uuid)
    prev_uuid = branch_uuids[start_uuid_index - 1]

    audit_uuids = branch_uuids[start_uuid_index:1 + final_uuid_index]
    msgs_df = read_labelmap_kafka_df(server,
                                     final_uuid,
                                     instance,
                                     drop_completes=True)

    ann_kv_log = read_kafka_messages(server, final_uuid,
                                     f'{instance}_annotations')
    ann_kv_log_df = kafka_msgs_to_df(ann_kv_log)
    ann_kv_log_df['body'] = ann_kv_log_df['key'].astype(np.uint64)

    ann_df = fetch_body_annotations(server, prev_uuid,
                                    f'{instance}_annotations')[['status']]

    split_events = fetch_supervoxel_splits(server, leaf_uuid, instance, 'dvid',
                                           'dict')

    bad_merge_events = []
    bad_cleave_events = []

    body_sizes = {}
    for cur_uuid in tqdm_proxy(audit_uuids):

        def get_body_size(body):
            if body in body_sizes:
                return body_sizes[body]
            try:
                body_sizes[body] = fetch_size(server, prev_uuid, instance,
                                              body)
                return body_sizes[body]
            except Exception:
                return 0

        logger.info(f"Auditing uuid '{cur_uuid}'")
        cur_msgs_df = msgs_df.query('uuid == @cur_uuid')
        cur_bad_merge_events = []
        cur_bad_cleave_events = []
        for row in tqdm_proxy(cur_msgs_df.itertuples(index=False),
                              total=len(cur_msgs_df),
                              leave=False):
            if row.target_body in ann_df.index:
                target_status = ann_df.loc[row.target_body, 'status']
            else:
                target_status = None

            # Check for merges that eliminate a traced body
            if row.action == "merge":
                target_body_size = get_body_size(row.target_body)
                # Check for traced bodies being merged INTO another
                _merged_labels = row.msg['Labels']
                bad_merges_df = ann_df.query(
                    'body in @_merged_labels and status in @TRACED_STATUSES')
                if len(bad_merges_df) > 0:
                    if target_status in TRACED_STATUSES:
                        cur_bad_merge_events.append(
                            ('merge-traced-to-traced', *row, target_status,
                             bad_merges_df.index.tolist(),
                             bad_merges_df['status'].tolist()))
                    else:
                        cur_bad_merge_events.append(
                            ('merge-traced-to-nontraced', *row, target_status,
                             bad_merges_df.index.tolist(),
                             bad_merges_df['status'].tolist()))

                # Update local body_sizes
                merged_size = sum([
                    get_body_size(merge_label)
                    for merge_label in row.msg['Labels']
                ])
                if target_body_size != 0:
                    body_sizes[row.target_body] += merged_size

            # Check for bodies that lost more than half of their size in a cleave
            if row.action == "cleave" and target_status in TRACED_STATUSES:
                target_body_size = get_body_size(row.target_body)
                if target_body_size == 0:
                    # Since we aren't assessing split events yet,
                    # it's possible that we don't have all the information we need to assess all bodies.
                    # Bodies that were created during splits are not tracked.
                    cur_bad_cleave_events.append(
                        ('failed-to-assess', *row, target_body_size, 0))
                else:
                    cleaved_sizes = fetch_sizes(server,
                                                cur_uuid,
                                                instance,
                                                row.msg['CleavedSupervoxels'],
                                                supervoxels=True)
                    for sv in tqdm_proxy(
                            cleaved_sizes[(cleaved_sizes == 0)].index,
                            leave=False):
                        cleaved_sizes.loc[sv] = fetch_retired_supervoxel_size(
                            server, leaf_uuid, instance, sv, split_events)

                    if cleaved_sizes.sum() > target_body_size / 2:
                        cur_bad_cleave_events.append(
                            ('large-cleave', *row, target_body_size,
                             cleaved_sizes.sum()))

                    # Update local body_sizes
                    body_sizes[row.target_body] -= cleaved_sizes.sum()
                    body_sizes[row.msg['CleavedLabel']] = cleaved_sizes.sum()

            # TODO: Check body split events, too.
            # We could apply the above message effects to ann_df,
            # but it doesn't actually matter.

        logger.info(
            f"Found {len(cur_bad_merge_events)} bad merges and {len(cur_bad_cleave_events)} bad cleaves"
        )
        bad_merge_events += cur_bad_merge_events
        bad_cleave_events += cur_bad_cleave_events

        # Rather than fetching the entire ann_df for the next uuid,
        # just update the keys that changed (according to kafka).

        # TODO: It would be interesting to compare the difference between
        #       statuses that we computed vs. the statuses in the new uuid
        updated_bodies = ann_kv_log_df.query('uuid == @cur_uuid')['body']
        ann_update_df = fetch_body_annotations(server, cur_uuid,
                                               f'{instance}_annotations',
                                               updated_bodies)[['status']]
        ann_df = pd.concat((ann_df, ann_update_df))
        ann_df = ann_df.loc[~(ann_df.index.duplicated(keep='last'))]

    bad_merges_df = pd.DataFrame(bad_merge_events,
                                 columns=[
                                     'reason', *msgs_df.columns,
                                     'target_status', 'traced_bodies',
                                     'traced_statuses'
                                 ])
    bad_cleaves_df = pd.DataFrame(
        bad_cleave_events,
        columns=['reason', *msgs_df.columns, 'target_size', 'cleave_size'])

    np.save('audit_bad_merges_df.npy', bad_merges_df.to_records(index=False))
    np.save('audit_bad_cleaves_df.npy', bad_cleaves_df.to_records(index=False))

    return bad_merges_df, bad_cleaves_df
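A hedged call sketch for the audit function; the server, branch, and uuids below are placeholders:

# Placeholder server/branch/uuids; abbreviated uuids are expanded internally.
bad_merges_df, bad_cleaves_df = audit_traced_history_for_branch(
    'http://example-dvid:8000',
    branch='production',
    instance='segmentation',
    start_uuid='a1b2c3',
    final_uuid='d4e5f6')

print(f"Found {len(bad_merges_df)} suspicious merges and {len(bad_cleaves_df)} suspicious cleaves")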