Example #1
def copy_vnc_subvolume(box_zyx,
                       copy_grayscale=True,
                       copy_segmentation=True,
                       chunk_shape=(64, 64, 2048)):
    assert not (box_zyx % 64).any(), \
        "Only 64px block-aligned volumes can be copied."

    import numpy as np
    from neuclease.util import boxes_from_grid, tqdm_proxy, round_box
    from neuclease.dvid import find_master, fetch_raw, post_raw, fetch_subvol, post_labelmap_voxels

    vnc_master = ('emdata4:8200', find_master('emdata4:8200'))

    NUM_SCALES = 8
    num_voxels = np.prod(box_zyx[1] - box_zyx[0])

    if copy_grayscale:
        logger.info(
            f"Copying grayscale from box {box_zyx[:,::-1].tolist()} ({num_voxels/1e6:.1f} Mvox) for {NUM_SCALES} scales"
        )
        for scale in tqdm_proxy(range(NUM_SCALES)):
            if scale == 0:
                input_name = 'grayscalejpeg'
                output_name = 'local-grayscalejpeg'
            else:
                input_name = f'grayscalejpeg_{scale}'
                output_name = f'local-grayscalejpeg_{scale}'

            scaled_box_zyx = np.maximum(box_zyx // 2**scale, 1)
            scaled_box_zyx = round_box(scaled_box_zyx, 64, 'out')

            for chunk_box in tqdm_proxy(boxes_from_grid(scaled_box_zyx,
                                                        chunk_shape,
                                                        clipped=True),
                                        leave=False):
                chunk = fetch_subvol(*vnc_master,
                                     input_name,
                                     chunk_box,
                                     progress=False)
                post_raw(*vnc_master, output_name, chunk_box[0], chunk)

    if copy_segmentation:
        logger.info(
            f"Copying segmentation from box {box_zyx[:,::-1].tolist()} ({num_voxels/1e6:.2f} Mvox)"
        )
        for chunk_box in tqdm_proxy(
                boxes_from_grid(box_zyx, chunk_shape, clipped=True)):
            chunk = fetch_raw(*vnc_master,
                              'segmentation',
                              chunk_box,
                              dtype=np.uint64)
            post_labelmap_voxels(*vnc_master,
                                 'local-segmentation',
                                 chunk_box[0],
                                 chunk,
                                 downres=True)

        # TODO: Update label indexes?

    logger.info("DONE")
Example #2
def check_tarsupervoxels_status_via_missing(server, uuid, tsv_instance,
                                            bodies):
    """
    For the given bodies, query the given tarsupervoxels instance and return a
    Series (indexed by supervoxel ID) identifying which supervoxels are 'missing'
    from the instance, along with the body each missing supervoxel belongs to.
    
    Bodies that no longer exist in the segmentation instance are ignored.

    This function uses the /missing endpoint, which incurs a disk read in DVID 
    for the LabelIndex of each body.
    """
    sv_body = []

    try:
        for body in tqdm_proxy(bodies):
            try:
                missing_svs = fetch_missing(server, uuid, tsv_instance, body)
            except requests.RequestException as ex:
                if 'has no supervoxels' in ex.args[0]:
                    continue
                else:
                    raise

            sv_body += [(sv, body) for sv in missing_svs]

    except KeyboardInterrupt:
        logger.warning(
            "Interrupted. Returning results so far.  Interrupt again to kill.")

    df = pd.DataFrame(sv_body, columns=['sv', 'body'], dtype=np.uint64)
    df.set_index('sv', inplace=True)
    return df['body']
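
A hypothetical invocation; the server, uuid, instance name, and body IDs below are placeholders. The result is a Series mapping each missing supervoxel to the body it belongs to:

bodies = [1071121755, 1002507170]   # placeholder body IDs
missing_svs = check_tarsupervoxels_status_via_missing(
    'emdata4:8900', 'abc123', 'segmentation_sv_meshes', bodies)
print(f"{len(missing_svs)} supervoxels are missing, "
      f"spread across {missing_svs.nunique()} bodies")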
def adjust_focused_points(server,
                          uuid,
                          instance,
                          assignment_json_data,
                          supervoxels=True,
                          max_search_scale=3):
    new_assignment_data = copy.deepcopy(assignment_json_data)
    new_tasks = new_assignment_data["task list"]

    for task in tqdm_proxy(new_tasks):
        sv_1 = task["supervoxel ID 1"]
        sv_2 = task["supervoxel ID 2"]

        coord_1 = np.array(task["supervoxel point 1"])
        coord_2 = np.array(task["supervoxel point 2"])

        if supervoxels:
            label_1 = sv_1
            label_2 = sv_2
        else:
            label_1, label_2 = fetch_mapping(server,
                                             uuid,
                                             instance, [sv_1, sv_2],
                                             as_series=True)

        avg_coord = (coord_1 + coord_2) // 2

        # Search until we find a scale in which the two touch, or give up.
        for scale in range(1 + max_search_scale):
            box_xyz = (avg_coord // (2**scale) - 64,
                       avg_coord // (2**scale) + 64)
            box_zyx = np.array(box_xyz)[:, ::-1]
            seg_vol = fetch_labelarray_voxels(server,
                                              uuid,
                                              instance,
                                              box_zyx,
                                              scale,
                                              supervoxels=supervoxels)

            adjusted_coords_zyx = find_best_plane(seg_vol, label_1, label_2)
            adjusted_coords_zyx = np.array(adjusted_coords_zyx)

            if not (adjusted_coords_zyx == -1).all():
                # Found it.
                adjusted_coords_zyx += box_zyx[0]
                adjusted_coords_zyx *= (2**scale)
                break

        if (adjusted_coords_zyx == -1).all():
            task["coordinate-status"] = "misplaced"
        else:
            task["supervoxel point 1"] = adjusted_coords_zyx[0, ::-1].tolist()
            task["supervoxel point 2"] = adjusted_coords_zyx[1, ::-1].tolist()
            task["coordinate-status"] = f"adjusted-at-scale-{scale}"

    return new_assignment_data
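
A sketch of how this might be used to round-trip a focused-proofreading assignment file; the file paths, server, uuid, and instance are placeholders:

import json

with open('focused-assignment.json', 'r') as f:      # placeholder path
    assignment = json.load(f)

adjusted = adjust_focused_points('emdata4:8900', 'abc123', 'segmentation',
                                 assignment, supervoxels=True, max_search_scale=3)

with open('focused-assignment-adjusted.json', 'w') as f:
    json.dump(adjusted, f, indent=2)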
Example #4
def export_sparsevol(server, uuid, instance, neurons_df, scale=5, format='tiff', output_dir='.'):
    import os
    import vigra
    import numpy as np

    from neuclease.util import round_box, tqdm_proxy
    from neuclease.dvid import fetch_sparsevol, resolve_ref, fetch_volume_box, box_to_slicing

    uuid = resolve_ref(server, uuid)

    # Determine the segmentation bounding box at the given scale,
    # which is used as the mask shape.
    seg = (server, uuid, instance)
    box = round_box(fetch_volume_box(*seg), 64, 'out')
    box[0] = (0,0,0)
    box_scaled = box // 2**scale

    # How many digits will we need in each slice file name?
    digits = int(np.ceil(np.log10(box_scaled[1, 0])))

    # Export a mask stack for each group.
    groups = neurons_df.groupby('group', sort=False)
    num_groups = neurons_df['group'].nunique()
    group_prog = tqdm_proxy(groups, total=num_groups)
    for group, df in group_prog:
        group_prog.write(f'Group "{group}": Assembling mask')
        group_mask = np.zeros(box_scaled[1], dtype=bool)
        group_mask = vigra.taggedView(group_mask, 'zyx')

        # Overlay each body mask in the current group
        for body in tqdm_proxy(df['body'], leave=False):
            body_mask, mask_box = fetch_sparsevol(*seg, body, scale=scale, format='mask')
            group_mask[box_to_slicing(*mask_box)] |= body_mask

        # Write out the slice files
        group_prog.write(f'Group "{group}": Writing slices')
        d = f'{output_dir}/{group}.stack'
        os.makedirs(d, exist_ok=True)
        for z in tqdm_proxy(range(group_mask.shape[0]), leave=False):
            p = ('{d}/{z:' + f'0{digits}' + 'd}.{f}').format(d=d, z=z, f=format)
            vigra.impex.writeImage(group_mask[z].astype(np.uint8), p)
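
A hypothetical call, with placeholder bodies grouped into two mask stacks. The input DataFrame only needs 'group' and 'body' columns:

import pandas as pd

neurons_df = pd.DataFrame({
    'group': ['groupA', 'groupA', 'groupB'],            # placeholder group names
    'body':  [1071121755, 1002507170, 1037161461],      # placeholder body IDs
})
export_sparsevol('emdata4:8900', 'abc123', 'segmentation',
                 neurons_df, scale=5, format='tiff', output_dir='mask-stacks')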
def renumber_groups(tabs_and_paths, output_dir, exclude_labels=[]):
    """
    Given a series of tab-wise label group CSVs,
    renumber the group IDs so that group IDs are not duplicated across tabs.
    """
    tables = []
    for tab, path in tqdm_proxy(tabs_and_paths.items()):
        tab_df = pd.read_csv(path, dtype=np.int64)
        assert tab_df.columns.tolist() == ['group', 'label']
        tab_df['tab'] = tab
        tables.append(tab_df)

    full_df = pd.concat(tables, ignore_index=True)
    full_df = full_df.query('label not in @exclude_labels')

    new_groups_df = full_df.drop_duplicates(['group',
                                             'tab']).reset_index(drop=True)
    new_groups_df.index.name = 'unique_group'
    new_groups_df = new_groups_df.reset_index()

    full_regrouped_df = full_df.merge(
        new_groups_df[['tab', 'group', 'unique_group']],
        'left',
        on=['tab', 'group'])

    full_regrouped_df = full_regrouped_df.drop(columns=['group']).rename(
        columns={'unique_group': 'group'})
    full_regrouped_df['group'] += 1

    os.makedirs(output_dir, exist_ok=True)
    for tab, tab_df in tqdm_proxy(full_regrouped_df.groupby('tab'),
                                  total=len(tabs_and_paths)):
        tab_df[['group', 'label'
                ]].to_csv(f'{output_dir}/renumbered-groups-tab{tab}.csv',
                          header=True,
                          index=False)
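
A hypothetical call with placeholder file names; each input CSV is expected to contain exactly two columns, 'group' and 'label':

tabs_and_paths = {
    22: 'tab22-groups.csv',     # placeholder paths
    23: 'tab23-groups.csv',
}
renumber_groups(tabs_and_paths, 'renumbered-groups', exclude_labels=[0])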
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--split-into-batches', type=int,
                        help='If given, also split the body stats into this many batches of roughly equal size')
    parser.add_argument('server')
    parser.add_argument('src_uuid')
    parser.add_argument('labelmap_instance')
    parser.add_argument('supervoxel_block_stats_h5',
                        help=f'An HDF5 file with a single dataset "stats", with dtype: {STATS_DTYPE[1:]} (Note: No column for body_id)')
    args = parser.parse_args()

    configure_default_logging()
    initialize_excepthook()
    (block_sv_stats, _presorted_by, _agglo_path) = load_stats_h5_to_records(args.supervoxel_block_stats_h5)

    src_info = (args.server, args.src_uuid, args.labelmap_instance)
    mapping = fetch_mappings(*src_info)

    assert isinstance(mapping, pd.Series)
    mapping_df = mapping.reset_index().rename(columns={'sv': 'segment_id', 'body': 'body_id'})

    # sorts in-place, and saves a copy to hdf5
    sort_block_stats( block_sv_stats,
                      mapping_df,
                      args.supervoxel_block_stats_h5[:-3] + '-sorted-by-body.h5',
                      '<fetched-from-dvid>')
    
    if args.split_into_batches:
        num_batches = args.split_into_batches
        batch_size = int(np.ceil(len(block_sv_stats) / args.split_into_batches))
        logger.info(f"Splitting into {args.split_into_batches} batches of size ~{batch_size}")
        os.makedirs('stats-batches', exist_ok=True)
        
        body_spans = groupby_spans_presorted(block_sv_stats['body_id'][:, None])
        for batch_index, batch_spans in enumerate(tqdm_proxy(iter_batches(body_spans, batch_size))):
            span_start, span_stop = batch_spans[0][0], batch_spans[-1][1]
            batch_stats = block_sv_stats[span_start:span_stop]
            digits = int(np.ceil(np.log10(num_batches)))
            batch_path = ('stats-batches/stats-batch-{:0' + str(digits) + 'd}.h5').format(batch_index)
            save_stats(batch_stats, batch_path)
    
    logger.info("DONE sorting stats by body")
def main():
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('-n', type=int, default=8)
    parser.add_argument('src')
    parser.add_argument('dest')
    args = parser.parse_args()

    args.src = abspath(args.src)
    args.dest = abspath(args.dest)

    if exists(args.dest):
        sys.exit(f"Error: Destination already exists: {args.dest}")

    from neuclease.util import compute_parallel, tqdm_proxy
    from neuclease import configure_default_logging
    configure_default_logging()

    os.chdir(args.src)
    logger.info("Listing source files")
    r = subprocess.run('find . -type f', shell=True, capture_output=True)
    src_paths = r.stdout.decode('utf-8').strip().split('\n')
    dest_paths = [f'{args.dest}/{p}' for p in src_paths]
    dest_dirs = sorted(set([*map(dirname, dest_paths)]))

    logger.info("Initializing directory tree")
    for d in tqdm_proxy(dest_dirs):
        os.makedirs(d, exist_ok=True)

    logger.info(f"Copying {len(src_paths)} files")
    compute_parallel(copyfile, [*zip(src_paths, dest_paths)],
                     10,
                     starmap=True,
                     ordered=False,
                     processes=args.n)

    logger.info("DONE")
    'id': '772a59939a8a65f95bb7f2e27dfe544a616ba15c'
}]

# Load boxes
ann_boxes_xyz = np.array([(a['pointA'], a['pointB']) for a in ng_ann])
ann_boxes_zyx = ann_boxes_xyz[..., ::-1]
boxes_zyx = np.zeros_like(ann_boxes_zyx)
boxes_zyx[:, 0, :] = ann_boxes_zyx.min(axis=1)
boxes_zyx[:, 1, :] = ann_boxes_zyx.max(axis=1)
assert (boxes_zyx[:, 1] - boxes_zyx[:, 0] > 0).all()

# Read segmentation and sizes
vnc_seg = ('emdata4:8450', '75d3ddd2e9e143a38fa9cc9e7d55b3d1', 'segmentation')
dfs = []
segs = []
for box in tqdm_proxy(boxes_zyx):
    seg = fetch_labelmap_voxels_chunkwise(*vnc_seg, box, threads=8)
    vc = pd.Series(seg.reshape(-1)).value_counts()
    df = vc.rename('boxed_size').rename_axis('body').reset_index()
    df['full_size'] = fetch_sizes(*vnc_seg,
                                  df['body'],
                                  batch_size=100,
                                  threads=8).values
    segs.append(seg)
    dfs.append(df)

for box, df in zip(boxes_zyx, dfs):
    name = '-'.join(map(str, box[0, ::-1].tolist()))
    df = df.sort_values('boxed_size', ascending=False)
    df.to_csv(f'seg-box-stats-{name}.csv', index=False, header=True)
Example #9
    def execute(self):
        options = self.config["mitodistances"]
        output_dir = self.config["output-directory"]
        body_svc, mito_svc = self.init_services()

        # Resource manager context must be initialized before resource manager client
        # (to overwrite config values as needed)
        dvid_mgr_config = self.config["dvid-access-manager"]
        dvid_mgr_context = LocalResourceManager(dvid_mgr_config)
        dvid_mgr_client = ResourceManagerClient(dvid_mgr_config["server"],
                                                dvid_mgr_config["port"])

        syn_server, syn_uuid, syn_instance = (options['synapse-criteria'][k]
                                              for k in ('server', 'uuid',
                                                        'instance'))
        syn_conf = float(options['synapse-criteria']['confidence'])
        syn_types = ['PreSyn', 'PostSyn']
        if options['synapse-criteria']['type'] == 'pre':
            syn_types = ['PreSyn']
        elif options['synapse-criteria']['type'] == 'post':
            syn_types = ['PostSyn']

        bodies = load_body_list(options["bodies"], False)
        skip_flags = [
            os.path.exists(f'{output_dir}/{body}.csv') for body in bodies
        ]
        bodies_df = pd.DataFrame({'body': bodies, 'should_skip': skip_flags})
        bodies = bodies_df.query('not should_skip')['body']

        # Shuffle for better load balance?
        # TODO: Would be better to sort by synapse count, and put large bodies first,
        #       assigned to partitions in round-robin style.
        #       Then work stealing will be more effective at knocking out the smaller jobs at the end.
        #       This requires knowing all the body sizes, though.
        #       Perhaps mito count would be a decent proxy for synapse count, and it's readily available.
        #bodies = bodies.sample(frac=1.0).values

        os.makedirs('body-logs')
        os.makedirs(output_dir, exist_ok=True)

        mito_server, mito_uuid, mito_instance = (options['mito-labelmap'][k]
                                                 for k in ('server', 'uuid',
                                                           'instance'))

        @auto_retry(3)
        def _fetch_synapses(body):
            with dvid_mgr_client.access_context(syn_server, True, 1, 1):
                syn_df = fetch_annotation_label(syn_server,
                                                syn_uuid,
                                                syn_instance,
                                                body,
                                                format='pandas')
                if len(syn_df) == 0:
                    return syn_df
                # Referencing these closure variables here forces Python to
                # capture them, so the query() below can resolve them via '@'.
                syn_types, syn_conf
                syn_df = syn_df.query(
                    'kind in @syn_types and conf >= @syn_conf').copy()
                return syn_df[[*'xyz', 'kind', 'conf'
                               ]].sort_values([*'xyz']).reset_index(drop=True)

        @auto_retry(3)
        def _fetch_mito_ids(body):
            with dvid_mgr_client.access_context(mito_server, True, 1, 1):
                try:
                    return fetch_supervoxels(mito_server, mito_uuid,
                                             mito_instance, body)
                except HTTPError:
                    return []

        def process_and_save(body):
            tbars = _fetch_synapses(body)
            valid_mitos = _fetch_mito_ids(body)

            # TODO:
            #   Does the stdout_redirected() mechanism work correctly in the context of multiprocessing?
            #   If not, I should probably just use a custom logging handler instead.
            with open(f"body-logs/{body}.log",
                      "w") as f, stdout_redirected(f), Timer() as timer:
                processed_tbars = []
                if len(tbars) == 0:
                    logging.getLogger(__name__).warning(
                        f"Body {body}: No synapses found")

                if len(valid_mitos) == 0:
                    logging.getLogger(__name__).warning(
                        f"Body {body}: Failed to fetch mito supervoxels")
                    processed_tbars = initialize_results(body, tbars)

                if len(valid_mitos) and len(tbars):
                    processed_tbars = measure_tbar_mito_distances(
                        body_svc,
                        mito_svc,
                        body,
                        tbars=tbars,
                        valid_mitos=valid_mitos)

            if len(processed_tbars) > 0:
                processed_tbars.to_csv(f'{output_dir}/{body}.csv',
                                       header=True,
                                       index=False)
                with open(f'{output_dir}/{body}.pkl', 'wb') as f:
                    pickle.dump(processed_tbars, f)

            if len(tbars) == 0:
                return (body, 0, 'no-synapses', timer.seconds)

            if len(valid_mitos) == 0:
                return (body, len(processed_tbars), 'no-mitos', timer.seconds)

            return (body, len(tbars), 'success', timer.seconds)

        logger.info(
            f"Processing {len(bodies)} bodies, "
            f"skipping {bodies_df['should_skip'].sum()} bodies whose output files already exist"
        )

        def process_batch(bodies):
            return [*map(process_and_save, bodies)]

        with dvid_mgr_context:
            batch_size = max(1, len(bodies) // 10_000)
            futures = self.client.map(process_batch,
                                      iter_batches(bodies, batch_size))

            # Support synchronous testing with a fake 'as_completed' object
            if hasattr(self.client, 'DEBUG'):
                ac = as_completed_synchronous(futures, with_results=True)
            else:
                ac = distributed.as_completed(futures, with_results=True)

            try:
                results = []
                for f, r in tqdm_proxy(ac, total=len(futures)):
                    results.extend(r)
            finally:
                results = pd.DataFrame(
                    results,
                    columns=['body', 'synapses', 'status', 'processing_time'])
                results.to_csv('results-summary.csv', header=True, index=False)
                num_errors = len(results.query('status == "error"'))
                if num_errors:
                    logger.warning(
                        f"Encountered {num_errors} errors. See results-summary.csv"
                    )
Example #10
def main():
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--annotation-instance', default='synapses')
    parser.add_argument('--labelmap-instance', default='segmentation')
    parser.add_argument('--labelsz-instance')
    parser.add_argument('server')
    parser.add_argument('uuid')
    parser.add_argument('elements_json')

    args = parser.parse_args()

    server = args.server
    uuid = args.uuid
    syn_instance = args.annotation_instance
    seg_instance = args.labelmap_instance

    ##
    ## 1. Create an 'annotation' instance to store the synapse data
    ##
    ##      POST .../instance
    ## 
    create_instance(server, uuid, syn_instance, 'annotation')

    ##
    ## 2. Upload the synapse elements.
    ##
    ##      POST .../elements
    ##
    ##    Note:
    ##      DVID stores these in block-aligned groups, based on the synapse coordinates.
    ##      Ingestion will be fastest if you pre-sort your JSON elements by 64px blocks,
    ##      in Z-Y-X order, as shown below.
    ##

    with open(args.elements_json, 'r') as f:
        elements = ujson.load(f)
    
    # Sort elements by block location (64px blocks)
    # FIXME: This code should work but I haven't tested it yet.  Might have a typo.
    elements_df = pd.DataFrame([(*e["Pos"], e) for e in elements], columns=['x', 'y', 'z', 'element'])
    elements_df[['z', 'y', 'x']] //= 64
    elements_df.sort_values(['z', 'y', 'x'], inplace=True)
    
    # Group blocks into larger chunks, with each chunk being 100 blocks
    # in the X direction (and 1x1 in the zy directions).
    elements_df['x'] //= 100

    # Ingest in chunks.
    num_chunks = elements_df[['z', 'y', 'x']].drop_duplicates().shape[0]
    chunked_df = elements_df.groupby(['z', 'y', 'x'])
    for _zyx, batch_elements_df in tqdm_proxy(chunked_df, total=num_chunks):
        post_elements(server, uuid, syn_instance, batch_elements_df['element'].tolist())
    
    ##
    ## 3. Sync the annotation instance to a pre-existing
    ##    segmentation (labelmap) instance.
    ##
    ##      POST .../sync
    ##
    ##    This 'subscribes' the annotation instance to changes in the segmentation,
    ##    keeping updated counts of synapses in each body.
    ##    This will enable the .../<annotation>/labels endpoint to work efficiently.
    ##
    post_sync(server, uuid, syn_instance, [seg_instance])
    
    ##
    ## 4. Reload the synapse instance AFTER the sync was configured (above).
    ##    For real-world data sizes (e.g. millions of synapses) this will take
    ##    a long time (hours).
    ##
    ##      POST .../reload
    ##
    post_reload(server, uuid, syn_instance)

    ##
    ## 5. (Optional)
    ##    For some proofreading protocols, you may wish to create a 'labelsz' (label size) instance,
    ##    which allows you to ask for the largest N bodies (by synapse count).
    ##
    ##
    if args.labelsz_instance:
        create_instance(server, uuid, args.labelsz_instance, 'labelsz')
        post_sync(server, uuid, args.labelsz_instance, [syn_instance])
        post_reload(server, uuid, args.labelsz_instance)
def measure_tbar_mito_distances(seg_src,
                                mito_src,
                                body,
                                *,
                                search_configs=DEFAULT_SEARCH_CONFIGS,
                                npclient=None,
                                tbars=None,
                                valid_mitos=None):
    """
    Search for the closest mito to each tbar in a list of tbars
    (or any set of points, really).

    FIXME: Rename this function.  It works for more than just tbars.

    Args:
        seg_src:
            (server, uuid, instance) OR a flyemflows VolumeService
            Labelmap instance for the neuron segmentation.
        mito_src:
            (server, uuid, instance) OR a flyemflows VolumeService
            Labelmap instance for the mitochondria "supervoxel"
            segmentation -- not just the "masks".
        body:
            The body ID of interest, on which the tbars reside.
        search_configs:
            A list of ``SearchConfig`` tuples.
            For each tbar, this function tries to locate a mitochondrion within
            the given search radius, using data downloaded from the given scale.
            If the search fails and no mito can be found, the function tries
            again using the next search criteria in the list.
            The radius should always be specified in scale-0 units,
            regardless of the scale at which you want to perform the analysis.
            Additionally, the data will be downloaded at the specified scale,
            then downsampled (with continuity-preserving downsampling) to a coarser
            scale for analysis.
            Notes:
                - Scale 4 is too low-res.  Stick with scale-3 or better.
                - Higher radius is more expensive, but some of that expense is
                  recouped because all points that fall within the radius are
                  analyzed at once.  See _measure_tbar_mito_distances()
                  implementation for details.
            dilation_radius_s0:
                If dilation_radius_s0 is non-zero, the segmentation will be "repaired" to close
                gaps, using a procedure involving a dilation of the given radius.
            dilation_exclusion_buffer_s0:
                We want to close small gaps in the segmentation, but only if we think
                they're really a gap in the actual segmentation, not if they are merely
                fingers of the same branch that are actually connected outside of our
                analysis volume. The dilation procedure tends to form such spurious
                connections near the volume border, so this parameter can be used to
                exclude a buffer (inner halo) near the border from dilation repairs.
        npclient:
            ``neuprint.Client`` to use when fetching the list of tbars that belong
            to the given body, unless you provide your own tbar points in the next
            argument.
        tbars:
            A DataFrame of tbar coordinates at least with columns ``['x', 'y', 'z']``.
        valid_mitos:
            If provided, only the listed mito IDs will be considered valid as search targets.
    Returns:
        DataFrame of tbar coordinates, mito distances, and mito coordinates.
        Points for which no nearby mito could be found (after trying all the given search_configs)
        will be marked with `done=False` in the results.
    """
    assert search_configs[-1].is_final, "Last search config should be marked is_final"
    assert all([not cfg.is_final for cfg in search_configs[:-1]]), \
        "Only the last search config should be marked is_final (no others)."

    # Fetch tbars
    if tbars is None:
        tbars = fetch_synapses(body, SC(type='pre', primary_only=True), client=npclient)

    tbars = initialize_results(body, tbars)

    if valid_mitos is None or len(valid_mitos) == 0:
        valid_mito_mapper = None
    else:
        valid_mitos = np.asarray(valid_mitos, dtype=np.uint64)
        valid_mito_mapper = LabelMapper(valid_mitos, valid_mitos)

    with tqdm_proxy(total=len(tbars)) as progress:
        for row in tbars.itertuples():
            # can't use row.done -- itertuples might be out-of-sync
            done = tbars['done'].loc[row.Index]
            if done:
                continue

            loop_logger = None
            for cfg in search_configs:
                prefix = (f"({row.x}, {row.y}, {row.z}) [ds={cfg.download_scale} "
                          f"as={cfg.analysis_scale} r={cfg.radius_s0:4} dil={cfg.dilation_radius_s0:2}] ")
                loop_logger = PrefixedLogger(logger, prefix)

                prev_num_done = tbars['done'].sum()
                _measure_tbar_mito_distances(
                    seg_src, mito_src, body, tbars, row.Index, cfg, valid_mito_mapper, loop_logger)
                num_done = tbars['done'].sum()

                progress.update(num_done - prev_num_done)
                done = tbars['done'].loc[row.Index]
                if done:
                    break

                if not cfg.is_final:
                    loop_logger.info("Search failed for primary tbar. Trying next search config!")

            if not done:
                loop_logger.warning(f"Failed to find a nearby mito for tbar at point {(row.x, row.y, row.z)}")
                progress.update(1)

    failed = np.isinf(tbars['mito-distance'])
    succeeded = ~failed
    logger.info(f"Found mitos for {succeeded.sum()} tbars, failed for {failed.sum()} tbars")

    return tbars
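
A sketch of one possible call, using (server, uuid, instance) tuples for both segmentations. All servers, uuids, instances, the neuprint dataset, and the body ID are placeholders:

import neuprint

npclient = neuprint.Client('neuprint.janelia.org', dataset='hemibrain:v1.2.1')
seg_src = ('emdata4:8900', 'abc123', 'segmentation')
mito_src = ('emdata4:8900', 'abc123', 'mito-objects')

tbar_results = measure_tbar_mito_distances(seg_src, mito_src, 1071121755,
                                            npclient=npclient)
print(tbar_results.query('done')[['x', 'y', 'z', 'mito-distance']].head())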
Example #12
        
        num_threads:
            How many threads to use, for parallel loading.
    """
    _check_instance(server, uuid, instance_name)
    block_sv_stats = sorted_block_sv_stats

    # 'processor' is declared as a global so it can be shared with
    # subprocesses quickly via implicit memory sharing via after fork()
    global processor
    instance_info = (server, uuid, instance_name)
    processor = StatsBatchProcessor(last_mutid, instance_info, tombstone_mode, block_sv_stats, subset_labels, check_mismatches)

    gen = generate_stats_batch_spans(block_sv_stats, batch_rows)

    progress_bar = tqdm_proxy(total=len(block_sv_stats), logger=logger)

    all_mismatch_ids = []
    all_missing_ids = []
    pool = multiprocessing.Pool(num_threads)
    with progress_bar, pool:
        # Rather than call pool.imap_unordered() with processor.process_batch(),
        # we use globally declared process_batch(), as explained below.
        for next_stats_batch_total_rows, batch_mismatches, batch_missing in pool.imap_unordered(process_batch, gen):
            if batch_mismatches:
                pd.Series(batch_mismatches).to_csv(f'labelindex-mismatches-{uuid}.csv', index=False, header=False, mode='a')
                all_mismatch_ids.extend( batch_mismatches )

            if batch_missing:
                pd.Series(batch_missing).to_csv(f'labelindex-missing-{uuid}.csv', index=False, header=False, mode='a')
                all_missing_ids.extend( batch_missing )
Example #13
def select_hulls_for_mito_bodies(mito_body_ct,
                                 mito_bodies_mask,
                                 mito_binary,
                                 body_seg,
                                 hull_masks,
                                 seed_bodies,
                                 box,
                                 scale,
                                 viewer=None,
                                 res0=8,
                                 progress=False):

    mito_bodies_mito_seg = np.where(mito_bodies_mask & mito_binary, body_seg,
                                    0)
    nonmito_body_seg = np.where(mito_bodies_mask, 0, body_seg)

    hull_cc_overlap_stats = []
    for hull_cc, (mask_box, mask) in tqdm_proxy(hull_masks.items(),
                                                disable=not progress):
        mbms = mito_bodies_mito_seg[box_to_slicing(*mask_box)]
        masked_hull_cc_bodies = np.where(mask, mbms, 0)
        # Faster to check for any non-zero values at all before trying to count them.
        # This early check saves a lot of time in practice.
        if not masked_hull_cc_bodies.any():
            continue

        # This hull was generated from a particular seed body (non-mito body).
        # If it accidentally overlaps with any other non-mito bodies,
        # then delete those voxels from the hull.
        # If that causes the hull to become split apart into multiple connected components,
        # then keep only the component(s) which overlap the seed body.
        seed_body = seed_bodies[hull_cc]
        nmbs = nonmito_body_seg[box_to_slicing(*mask_box)]
        other_bodies = set(pd.unique(nmbs[mask])) - {0, seed_body}
        if other_bodies:
            # Keep only the voxels on mito bodies or on the
            # particular non-mito body for this hull (the "seed body").
            mbm = mito_bodies_mask[box_to_slicing(*mask_box)]
            mask[:] &= (mbm | (nmbs == seed_body))
            mask = vigra.taggedView(mask, 'zyx')
            mask_cc = vigra.analysis.labelMultiArrayWithBackground(
                mask.view(np.uint8))
            if mask_cc.max() > 1:
                mask_ct = contingency_table(mask_cc, nmbs).reset_index()
                keep_ccs = mask_ct['left'].loc[(mask_ct['left'] != 0) &
                                               (mask_ct['right'] == seed_body)]
                mask[:] = mask_for_labels(mask_cc, keep_ccs)

        mito_bodies, counts = np.unique(masked_hull_cc_bodies,
                                        return_counts=True)
        overlaps = pd.DataFrame({
            'mito_body': mito_bodies,
            'overlap': counts,
            'hull_cc': hull_cc,
            'hull_size': mask.sum(),
            'hull_body': seed_body
        })
        hull_cc_overlap_stats.append(overlaps)

    if len(hull_cc_overlap_stats) == 0:
        logger.warning("Could not find any matches for any mito bodies!")
        mito_body_ct['hull_body'] = np.uint64(0)
        return mito_body_ct

    hull_cc_overlap_stats = pd.concat(hull_cc_overlap_stats, ignore_index=True)
    hull_cc_overlap_stats = hull_cc_overlap_stats.query(
        'mito_body != 0').copy()

    # Aggregate the stats for each body and the hull bodies it overlaps with,
    # then select the hull_body with the most overlap, or in the case of ties,
    # the hull body that is largest overall.
    # (Ties are probably more common in the event that two hulls completely encompass a small mito body.)
    hull_body_overlap_stats = hull_cc_overlap_stats.groupby(
        ['mito_body', 'hull_body'])[['overlap', 'hull_size']].sum()
    hull_body_overlap_stats = hull_body_overlap_stats.sort_values(
        ['mito_body', 'overlap', 'hull_size'], ascending=False)
    hull_body_overlap_stats = hull_body_overlap_stats.reset_index()

    mito_hull_selections = (hull_body_overlap_stats.drop_duplicates(
        'mito_body').set_index('mito_body')['hull_body'])
    mito_body_ct = mito_body_ct.merge(mito_hull_selections,
                                      'left',
                                      left_index=True,
                                      right_index=True)
    mito_body_ct['hull_body'] = mito_body_ct['hull_body'].fillna(0)

    dtypes = {col: np.float32 for col in mito_body_ct.columns}
    dtypes['hull_body'] = np.uint64
    mito_body_ct = mito_body_ct.astype(dtypes)

    if viewer:
        assert mito_hull_selections.index.dtype == mito_hull_selections.values.dtype == np.uint64
        mito_hull_mapper = LabelMapper(mito_hull_selections.index.values,
                                       mito_hull_selections.values)
        remapped_body_seg = mito_hull_mapper.apply(body_seg, True)
        remapped_body_seg = apply_mask_for_labels(remapped_body_seg,
                                                  mito_hull_selections.values)
        update_seg_layer(viewer, 'altered-bodies', remapped_body_seg, scale,
                         box)

        # Show the final hull masks (after erasure of non-target bodies)
        assert sorted(hull_masks.keys()) == [*range(1, 1 + len(hull_masks))]
        hull_cc_overlap_stats = hull_cc_overlap_stats.sort_values('hull_size')
        hull_seg = np.zeros_like(remapped_body_seg)
        for row in hull_cc_overlap_stats.itertuples():
            mask_box, mask = hull_masks[row.hull_cc]
            view = hull_seg[box_to_slicing(*mask_box)]
            view[:] = np.where(mask, row.hull_body, view)
        update_seg_layer(viewer, 'final-hull-seg', hull_seg, scale, box)

    return mito_body_ct
Example #14
    def execute(self):
        self._sanitize_config()
        self._prepare_output()

        input_config = self.config["input"]["dvid"]
        output_config = self.config["output"]
        options = self.config["svdecimate"]
        resource_config = self.config["resource-manager"]

        resource_mgr_client = ResourceManagerClient(resource_config["server"], resource_config["port"])

        server = input_config["server"]
        uuid = input_config["uuid"]
        tsv_instance = input_config["tarsupervoxels-instance"]

        bodies = load_body_list(options["bodies"], False)

        # Determine segmentation instance
        info = fetch_instance_info(server, uuid, tsv_instance)
        input_format = info["Extended"]["Extension"]

        output_format = options["format"]

        if (np.array(options["rescale"]) == 1.0).all() and output_format == "ngmesh" and input_format != "ngmesh":
            logger.warning("*** You are converting to ngmesh format, but you have not specified a rescale parameter! ***")

        decimation_lib = options["decimation-library"]
        max_sv_vertices = options["max-sv-vertices"]
        max_body_vertices = options["max-body-vertices"]
        num_procs = options["processes-per-body"]

        def process_body(body_id):
            with resource_mgr_client.access_context( input_config["server"], True, 1, 0 ):
                tar_bytes = fetch_tarfile(server, uuid, tsv_instance, body_id)

            sv_meshes = Mesh.from_tarfile(tar_bytes, concatenate=False)
            sv_meshes = {int(os.path.splitext(name)[0]): m for name, m in sv_meshes.items()}

            total_body_vertices = sum([len(m.vertices_zyx) for m in sv_meshes.values()])
            decimation = min(1.0, max_body_vertices / total_body_vertices)

            try:
                _process_sv = partial(process_sv, decimation, decimation_lib, max_sv_vertices, output_format)
                if num_procs <= 1:
                    output_table = [*starmap(_process_sv, sv_meshes.items())]
                else:
                    output_table = compute_parallel(_process_sv, sv_meshes.items(), starmap=True, processes=num_procs, ordered=False, show_progress=False)

                cols = ['sv', 'orig_vertices', 'final_vertices', 'final_decimation', 'effective_decimation', 'mesh_bytes']
                output_df = pd.DataFrame(output_table, columns=cols)
                output_df['body'] = body_id
                output_df['error'] = ""
                write_sv_meshes(output_df, output_config, output_format, resource_mgr_client)
            except Exception as ex:
                svs = [*sv_meshes.keys()]
                orig_vertices = [len(m.vertices_zyx) for m in sv_meshes.values()]
                output_df = pd.DataFrame({'sv': svs, 'orig_vertices': orig_vertices})
                output_df['final_vertices'] = -1
                output_df['final_decimation'] = -1
                output_df['effective_decimation'] = -1
                output_df['mesh_bytes'] = -1
                output_df['body'] = body_id
                output_df['error'] = str(ex)

            return output_df.drop(columns=['mesh_bytes'])

        futures = self.client.map(process_body, bodies)

        # Support synchronous testing with a fake 'as_completed' object
        if hasattr(self.client, 'DEBUG'):
            ac = as_completed_synchronous(futures, with_results=True)
        else:
            ac = distributed.as_completed(futures, with_results=True)

        try:
            stats = []
            for f, r in tqdm_proxy(ac, total=len(futures)):
                stats.append(r)
                if (r['error'] != "").any():
                    body = r['body'].iloc[0]
                    logger.warning(f"Body {body} failed!")

        finally:
            stats_df = pd.concat(stats)
            stats_df.to_csv('mesh-stats.csv', index=False, header=True)
            with open('mesh-stats.pkl', 'wb') as f:
                pickle.dump(stats_df, f)
Example #15
def load_roi_label_volume(server,
                          uuid,
                          rois_or_neuprint,
                          box_s5=[None, None],
                          export_path=None,
                          export_labelmap=None):
    """
    Fetch several ROIs from DVID and combine them into a single label volume or mask.
    The label values in the returned volume correspond to the order in which the ROI
    names were passed in, starting at label 1.
    
    This function is essentially a convenience function around fetch_combined_roi_volume(),
    but in this case it will optionally auto-fetch the ROI list, and auto-export the volume.
    
    Args:
        server:
            DVID server

        uuid:
            DVID uuid

        rois_or_neuprint:
            Either a list of ROIs or a neuprint server from which to obtain the roi list.

        box_s5:
            If you want to restrict the ROIs to a particular subregion,
            you may pass your own bounding box (at scale 5).
            Alternatively, you may pass the name of a segmentation
            instance from DVID whose bounding box will be used.

        export_path:
            If you want the ROI volume to be exported to disk,
            provide a path name ending with .npy or .h5.
        
        export_labelmap:
            If you want the ROI volume to be exported to a DVID labelmap instance,
            Provide the instance name, or a tuple of (server, uuid, instance).
    
    Returns:
        (roi_vol, roi_box, rois), containing the fetched label volume, the
        bounding box it corresponds to (in DVID scale-5 coordinates), and the
        list of ROI names that were used.

    Note:
      If you have a list of (full-res) points to extract from the returned volume,
      pass a DataFrame with columns ['z','y','x'] to the following function.
      If you already downloaded the roi_vol (above), provide it.
      Otherwise, leave out those args and it will be fetched first.
      Adds columns to the input DF (in-place) for 'roi' (str) and 'roi_label' (int).
    
        >>> from neuclease.dvid import determine_point_rois
        >>> determine_point_rois(*master, rois, point_df, roi_vol, roi_box)
    """
    if isinstance(box_s5, str):
        # Assume that this is a segmentation instance whose dimensions should be used
        # Fetch the maximum extents of the segmentation,
        # and rescale it for scale-5.
        seg_box = fetch_volume_box(server, uuid, box_s5)
        box_s5 = round_box(seg_box, (2**5), 'out') // 2**5
        box_s5[0] = (0, 0, 0)

    if export_labelmap:
        assert isinstance(box_s5, np.ndarray)
        assert not (box_s5 % 64).any(), \
            ("If exporting to a labelmap instance, please supply "
             "an explicit box and make sure it is block-aligned.")

    if isinstance(rois_or_neuprint, (str, neuprint.Client)):
        if isinstance(rois_or_neuprint, str):
            npclient = neuprint.Client(rois_or_neuprint)
        else:
            npclient = rois_or_neuprint

        # Fetch ROI names from neuprint
        q = "MATCH (m: Meta) RETURN m.superLevelRois as rois"
        rois = npclient.fetch_custom(q)['rois'].iloc[0]
        rois = sorted(rois)
        # # Remove '.*ACA' ROIs. Apparently there is some
        # # problem with them. (They overlap with other ROIs.)
        # rois = [*filter(lambda r: 'ACA' not in r, rois)]
    else:
        assert isinstance(rois_or_neuprint, collections.abc.Iterable)
        rois = rois_or_neuprint

    # Fetch each ROI and write it into a volume
    with Timer(f"Fetching combined ROI volume for {len(rois)} ROIs", logger):
        roi_vol, roi_box, overlap_stats = fetch_combined_roi_volume(
            server, uuid, rois, box_zyx=box_s5)

    if len(overlap_stats) > 0:
        logger.warning(
            f"Some ROIs overlap! Here's an incomplete list of overlapping pairs:\n{overlap_stats}"
        )

    # Export to npy/h5py for external use
    if export_path:
        with Timer(f"Exporting to {export_path}", logger):
            if export_path.endswith('.npy'):
                np.save(export_path, roi_vol)
            elif export_path.endswith('.h5'):
                with h5py.File(export_path, 'w') as f:
                    f.create_dataset('rois_scale_5', data=roi_vol, chunks=True)

    if export_labelmap:
        if isinstance(export_labelmap, str):
            export_labelmap = (server, uuid, export_labelmap)

        assert len(export_labelmap) == 3
        with Timer(f"Exporting to {export_labelmap[2]}", logger):
            if export_labelmap[2] not in fetch_repo_instances(
                    server, uuid, 'labelmap'):
                create_labelmap_instance(
                    *export_labelmap, voxel_size=8 * (2**5),
                    max_scale=6)  # FIXME: hard-coded voxel size

            # It's really important to use this block shape.
            # See https://github.com/janelia-flyem/dvid/issues/342
            boxes = boxes_from_grid(roi_box, (256, 256, 256), clipped=True)
            for box in tqdm_proxy(boxes):
                block = extract_subvol(roi_vol, box - roi_box[0])
                post_labelmap_voxels(*export_labelmap,
                                     box[0],
                                     block,
                                     scale=0,
                                     downres=True)

    return roi_vol, roi_box, rois
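
A hedged example of one way to call it, with a placeholder server and uuid: ROI names are fetched from neuprint, the volume is cropped to the segmentation's scale-5 bounding box, and the result is also saved to disk:

roi_vol, roi_box, rois = load_roi_label_volume(
    'emdata4:8900', 'abc123',                 # placeholder server/uuid
    'https://neuprint.janelia.org',           # or pass an explicit list of ROI names
    box_s5='segmentation',
    export_path='roi-vol-scale5.npy')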
Example #16
def audit_traced_history_for_branch(server,
                                    branch='',
                                    instance='segmentation',
                                    start_uuid=None,
                                    final_uuid=None):
    """
    Audits a labelmap kafka log and dvid history to
    check for the following spurious events:
    
        Merges:
            - A traced body should never be merged into another
              non-traced body and lose its body id
            - Two traced bodies that are merged together should be flagged;
              this might happen if a 'Leaves' body is merged into a 'Roughly traced' body
        
        Cleaves/body-splits:
            - A traced body should never be split where the bigger
              piece (or close to half the size) is split off

    Note:
        This function does not yet handle split events,
        which leads to certain cases that will be missed here.
        Furthermore, the body sizes in the results are not guaranteed to be accurate.
    """
    if start_uuid is not None:
        start_uuid = expand_uuid(server, start_uuid)

    if final_uuid is not None:
        final_uuid = expand_uuid(server, final_uuid)

    branch_uuids = find_branch_nodes(server, final_uuid, branch)
    leaf_uuid = branch_uuids[-1]

    if start_uuid is None:
        # Start with the second uuid by default
        start_uuid = branch_uuids[1]
    else:
        assert start_uuid != branch_uuids[0], \
            "Can't start from the root uuid, since the size/status information before that is unknown."

    if final_uuid is None:
        final_uuid = branch_uuids[-1]

    start_uuid_index = branch_uuids.index(start_uuid)
    final_uuid_index = branch_uuids.index(final_uuid)
    prev_uuid = branch_uuids[start_uuid_index - 1]

    audit_uuids = branch_uuids[start_uuid_index:1 + final_uuid_index]
    msgs_df = read_labelmap_kafka_df(server,
                                     final_uuid,
                                     instance,
                                     drop_completes=True)

    ann_kv_log = read_kafka_messages(server, final_uuid,
                                     f'{instance}_annotations')
    ann_kv_log_df = kafka_msgs_to_df(ann_kv_log)
    ann_kv_log_df['body'] = ann_kv_log_df['key'].astype(np.uint64)

    ann_df = fetch_body_annotations(server, prev_uuid,
                                    f'{instance}_annotations')[['status']]

    split_events = fetch_supervoxel_splits(server, leaf_uuid, instance, 'dvid',
                                           'dict')

    bad_merge_events = []
    bad_cleave_events = []

    body_sizes = {}
    for cur_uuid in tqdm_proxy(audit_uuids):

        def get_body_size(body):
            if body in body_sizes:
                return body_sizes[body]
            try:
                body_sizes[body] = fetch_size(server, prev_uuid, instance,
                                              body)
                return body_sizes[body]
            except Exception:
                return 0

        logger.info(f"Auditing uuid '{cur_uuid}'")
        cur_msgs_df = msgs_df.query('uuid == @cur_uuid')
        cur_bad_merge_events = []
        cur_bad_cleave_events = []
        for row in tqdm_proxy(cur_msgs_df.itertuples(index=False),
                              total=len(cur_msgs_df),
                              leave=False):
            if row.target_body in ann_df.index:
                target_status = ann_df.loc[row.target_body, 'status']
            else:
                target_status = None

            # Check for merges that eliminate a traced body
            if row.action == "merge":
                target_body_size = get_body_size(row.target_body)
                # Check for traced bodies being merged INTO another
                _merged_labels = row.msg['Labels']
                bad_merges_df = ann_df.query(
                    'body in @_merged_labels and status in @TRACED_STATUSES')
                if len(bad_merges_df) > 0:
                    if target_status in TRACED_STATUSES:
                        cur_bad_merge_events.append(
                            ('merge-traced-to-traced', *row, target_status,
                             bad_merges_df.index.tolist(),
                             bad_merges_df['status'].tolist()))
                    else:
                        cur_bad_merge_events.append(
                            ('merge-traced-to-nontraced', *row, target_status,
                             bad_merges_df.index.tolist(),
                             bad_merges_df['status'].tolist()))

                # Update local body_sizes
                merged_size = sum([
                    get_body_size(merge_label)
                    for merge_label in row.msg['Labels']
                ])
                if target_body_size != 0:
                    body_sizes[row.target_body] += merged_size

            # Check for bodies that lost more than half of their size in a cleave
            if row.action == "cleave" and target_status in TRACED_STATUSES:
                target_body_size = get_body_size(row.target_body)
                if target_body_size == 0:
                    # Since we aren't assessing split events yet,
                    # it's possible that we don't have all the information we need to assess all bodies.
                    # Bodies that were created during splits are not tracked.
                    cur_bad_cleave_events.append(
                        ('failed-to-assess', *row, target_body_size, 0))
                else:
                    cleaved_sizes = fetch_sizes(server,
                                                cur_uuid,
                                                instance,
                                                row.msg['CleavedSupervoxels'],
                                                supervoxels=True)
                    for sv in tqdm_proxy(
                            cleaved_sizes[(cleaved_sizes == 0)].index,
                            leave=False):
                        cleaved_sizes.loc[sv] = fetch_retired_supervoxel_size(
                            server, leaf_uuid, instance, sv, split_events)

                    if cleaved_sizes.sum() > target_body_size / 2:
                        cur_bad_cleave_events.append(
                            ('large-cleave', *row, target_body_size,
                             cleaved_sizes.sum()))

                    # Update local body_sizes
                    body_sizes[row.target_body] -= cleaved_sizes.sum()
                    body_sizes[row.msg['CleavedLabel']] = cleaved_sizes.sum()

            # TODO: Check body split events, too.
            # We could apply the above message effects to ann_df,
            # but it doesn't actually matter.

        logger.info(
            f"Found {len(cur_bad_merge_events)} bad merges and {len(cur_bad_cleave_events)} bad cleaves"
        )
        bad_merge_events += cur_bad_merge_events
        bad_cleave_events += cur_bad_cleave_events

        # Rather than fetching the entire ann_df for the next uuid,
        # just update the keys that changed (according to kafka).

        # TODO: It would be interesting to compare the difference between
        #       statuses that we computed vs. the statuses in the new uuid
        updated_bodies = ann_kv_log_df.query('uuid == @cur_uuid')['body']
        ann_update_df = fetch_body_annotations(server, cur_uuid,
                                               f'{instance}_annotations',
                                               updated_bodies)[['status']]
        ann_df = pd.concat((ann_df, ann_update_df))
        ann_df = ann_df.loc[~(ann_df.index.duplicated(keep='last'))]

    bad_merges_df = pd.DataFrame(bad_merge_events,
                                 columns=[
                                     'reason', *msgs_df.columns,
                                     'target_status', 'traced_bodies',
                                     'traced_statuses'
                                 ])
    bad_cleaves_df = pd.DataFrame(
        bad_cleave_events,
        columns=['reason', *msgs_df.columns, 'target_size', 'cleave_size'])

    np.save('audit_bad_merges_df.npy', bad_merges_df.to_records(index=False))
    np.save('audit_bad_cleaves_df.npy', bad_cleaves_df.to_records(index=False))

    return bad_merges_df, bad_cleaves_df
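
A hypothetical audit of a branch's recent history; the server and uuid are placeholders, and leaving start_uuid as None starts the audit from the branch's second uuid:

bad_merges_df, bad_cleaves_df = audit_traced_history_for_branch(
    'emdata4:8900',
    branch='',
    instance='segmentation',
    final_uuid='abc123')
print(f"Found {len(bad_merges_df)} suspicious merges "
      f"and {len(bad_cleaves_df)} suspicious cleaves")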
Example #17
    def _initialize(self):
        # Compute Y-midpoint of each block at every level
        group_midpoints = []
        for level in range(self.num_levels):
            midpoints = {}
            for block, group_df in self.partition_df.groupby(level,
                                                             sort=False)[[]]:
                midpoints[block] = (group_df.index.min() +
                                    group_df.index.max()) / 2
            group_midpoints.append(midpoints)

        # Construct line segment endpoints that will draw the tree hierarchy
        all_group_modes = {}

        # At the bottom level, every node is its own group, so the 'mode' is trivial.
        all_group_modes[0] = self.partition_df[[0, 'type'
                                                ]].rename(columns={
                                                    0: 'node',
                                                    'type': 'type_mode'
                                                })
        all_group_modes[0]['level'] = 0

        tree_line_segments = []
        for level in tqdm_proxy(range(1, self.num_levels)):
            #lower_group_points = []
            for lower_node, node in enumerate(self.nbs.get_bs()[level - 1]):
                left_x = level - 1
                right_x = level
                left_y = group_midpoints[level - 1][lower_node]
                right_y = group_midpoints[level][node]

                tree_line_segments.append([
                    level - 1, lower_node, node, left_x, left_y, right_x,
                    right_y
                ])

            # Drop null types before finding most common,
            # so that we find "most common non-null type"
            modes = (self.partition_df[[
                'type', level
            ]].query("type != ''").groupby(level)['type'].apply(
                lambda s: s.mode().iloc[0]).rename('type_mode'))

            # Since type-less nodes were dropped above, nodes which are empty "all the way down" are missing.
            # Add them back in, with an empty string.
            modes = (self.partition_df[[level]].drop_duplicates().merge(
                modes, 'left', left_on=level,
                right_index=True).fillna('').set_index(level)['type_mode'])

            all_group_modes[level] = modes
            all_group_modes[level].index.name = 'node'
            all_group_modes[level] = all_group_modes[level].reset_index()
            all_group_modes[level]['level'] = level

        # Add a 0-length line segment for the root node
        root_x = self.num_levels - 1
        root_y = tree_line_segments[-1][-1]
        tree_line_segments.append(
            [self.num_levels - 1, 0, -1, root_x, root_y, root_x, root_y])
        tree_line_segments_df = pd.DataFrame(tree_line_segments,
                                             columns=[
                                                 'level', 'node', 'parent',
                                                 'x', 'y', 'parent_x',
                                                 'parent_y'
                                             ])

        all_stats_df = pd.concat(all_group_modes.values())
        all_stats_df = all_stats_df.merge(tree_line_segments_df,
                                          'left',
                                          on=['level', 'node'])
        all_stats_df['color'] = 'gray'
        all_stats_df = all_stats_df[[
            'level', 'node', 'parent', 'x', 'y', 'parent_x', 'parent_y',
            'type_mode', 'color'
        ]]

        # Initialize strengths with 0 width lines (will be updated dynamically)
        strengths_df = all_stats_df.query('level == 0')[['node', 'y']].copy()
        strengths_df = (self.partition_df[[0, 'body', 'type',
                                           'instance']].rename(columns={
                                               0: 'node'
                                           }).merge(strengths_df,
                                                    'left',
                                                    on=['node']))

        in_strengths_df = strengths_df.copy()
        out_strengths_df = strengths_df.copy()
        del strengths_df

        for df, color in [(in_strengths_df, 'red'),
                          (out_strengths_df, 'green')]:
            df['color'] = color
            df['visible'] = False
            df['left'] = 0
            df['right'] = 0
            df['weight'] = 0
            df['height'] = 1.0
            df['rois'] = ''

        # Save members
        self.group_midpoints = group_midpoints
        self.all_stats_df = all_stats_df
        self.strengths_df = out_strengths_df
        self.in_strengths_df = in_strengths_df
        self.out_strengths_df = out_strengths_df
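# Not part of the original example: an illustrative sketch (assuming matplotlib
# is available) of how the segment endpoints computed in _initialize() above
# could be rendered.  'all_stats_df' is any frame with the columns
# ['x', 'y', 'parent_x', 'parent_y', 'color'] produced above.
import matplotlib.pyplot as plt

def draw_tree_sketch(all_stats_df):
    fig, ax = plt.subplots()
    for row in all_stats_df.itertuples():
        # Each row is one edge of the hierarchy: (x, y) -> (parent_x, parent_y)
        ax.plot([row.x, row.parent_x], [row.y, row.parent_y],
                color=row.color, linewidth=0.5)
    ax.set_xlabel('level')
    ax.set_ylabel('leaf-order midpoint (Y)')
    return fig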
def main():
    configure_default_logging()
    
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('--fraction', type=float,
                        help='Fraction of vertices to retain in the decimated mesh.  Between 0.0 and 1.0')
    parser.add_argument('--max-vertices', type=float, default=1e9,
                        help='If necessary, decimate the mesh even further so that it has no more than this vertex count (approximately).')
    parser.add_argument('--format',
                        help="Either 'obj', 'drc', or 'ngmesh'", required=True)
    parser.add_argument('--rescale', type=float,
                        help='Multiply all vertex coordinates by this factor before storing the mesh. Important for writing to ngmesh format.')
    parser.add_argument('--output-directory', '-d',
                        help='Directory to dump decimated meshes.')
    parser.add_argument('--output-url', '-u',
                        help='DVID keyvalue instance to write decimated mesh files to, '
                        'specified as a complete URL, e.g. http://emdata1:8000/api/node/123abc/my-meshes')
    parser.add_argument('server', help='dvid server, e.g. emdata3:8900')
    parser.add_argument('uuid', help='dvid node')
    parser.add_argument('tsv_instance', help='name of a tarsupervoxels instance, e.g. segmentation_sv_meshes')    
    parser.add_argument('bodies', nargs='+',
                        help='A list of body IDs OR a path to a CSV containing a column named "body", which will be read.\n'
                             'If no "body" column exists, the first column is used, regardless of the name.')

    args = parser.parse_args()

    if args.fraction is None:
        raise RuntimeError("Please specify a decimation fraction.")

    if args.format is None:
        raise RuntimeError("Please specify an output format (either 'drc' or 'obj' via --format")

    if args.output_directory:
        os.makedirs(args.output_directory, exist_ok=True)

    if args.format == "ngmesh" and args.rescale is None:
        raise RuntimeError("When writing to ngmesh, please specify an explict rescale factor.")

    args.rescale = args.rescale or 1.0

    output_dvid = None    
    if args.output_url:
        if '/api/node' not in args.output_url:
            raise RuntimeError("Please specify the output instance as a complete URL, "
                               "e.g. http://emdata1:8000/api/node/123abc/my-meshes")
        
        # drop 'http://' (if present)
        url = args.output_url.split('://')[-1]
        parts = url.split('/')
        assert parts[1] == 'api'
        assert parts[2] == 'node'
        
        output_server = parts[0]
        output_uuid = parts[3]
        output_instance = parts[4]
        
        output_dvid = (output_server, output_uuid, output_instance)


    all_bodies = []
    for body in args.bodies:
        if body.endswith('.csv'):
            if 'body' in read_csv_header(body):
                bodies = pd.read_csv(body)['body'].drop_duplicates()
            else:
                # Just read the first column, no matter what it's named
                bodies = read_csv_col(body, 0, np.uint64).drop_duplicates()
        else:
            try:
                # A single body ID was given on the command line.
                bodies = [int(body)]
            except ValueError:
                raise RuntimeError(f"Invalid body ID: '{body}'")
        
        all_bodies.extend(bodies)

    for body_id in tqdm_proxy(all_bodies):
        output_path = None
        if args.output_directory:
            output_path = f'{args.output_directory}/{body_id}.{args.format}'

        decimate_existing_mesh(args.server, args.uuid, args.tsv_instance,
                               body_id, args.fraction, args.max_vertices,
                               args.rescale, args.format, output_path,
                               output_dvid)
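# Not part of the original example: a minimal sketch of driving the same
# decimation workflow programmatically, mirroring the positional arguments of
# the decimate_existing_mesh() call in main() above.  The server, uuid, and
# instance values below are placeholders.
import os

def decimate_bodies_sketch(bodies, fraction=0.1, fmt='obj', output_dir='meshes'):
    os.makedirs(output_dir, exist_ok=True)
    for body_id in bodies:
        output_path = f'{output_dir}/{body_id}.{fmt}'
        decimate_existing_mesh('emdata3:8900', 'abc123', 'segmentation_sv_meshes',
                               body_id, fraction, 1e9, 1.0, fmt,
                               output_path, None)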
Example #19
0
def copy_splits_exact(src_server,
                      src_uuid,
                      src_instance,
                      dest_server,
                      dest_uuid,
                      dest_instance,
                      kafka_msgs,
                      min_timestamp=None,
                      max_timestamp=None,
                      min_mutid=None,
                      max_mutid=None,
                      pause_between_splits=0.0):
    src_seg = (src_server, src_uuid, src_instance)
    dest_seg = (dest_server, dest_uuid, dest_instance)

    min_timestamp = parse_timestamp(min_timestamp)
    max_timestamp = parse_timestamp(max_timestamp)

    kafka_msgs = filter_kafka_msgs_by_timerange(kafka_msgs, min_timestamp,
                                                max_timestamp, min_mutid,
                                                max_mutid)

    split_events = fetch_supervoxel_splits_from_kafka(*src_seg,
                                                      kafka_msgs=kafka_msgs)
    split_df = split_events_to_dataframe(split_events)

    # We assume that supervoxel ID values are toposorted, so sorting by
    # new 'split' ID is sufficient to ensure in-order splits.
    # (supervoxel splits already appear in the log in-order, but arbitrary splits
    # contain a batch of splits with identical mutation IDs, whose split IDs do
    # not necessarily appear in sorted order.)
    split_df.sort_values('split', inplace=True)

    split_forest = split_events_to_graph(split_events)

    def get_combined_leaf_sparsevol(sv):
        descendents = nx.descendants(split_forest, sv)
        descendents.add(sv)
        leaves = list(
            filter(lambda d: split_forest.out_degree(d) == 0, descendents))
        combined_sparsevol = fetch_and_combine_sparsevols(*src_seg,
                                                          leaves,
                                                          supervoxels=True)
        return leaves, combined_sparsevol

    for row in tqdm_proxy(split_df.itertuples(index=False),
                          total=len(split_df)):
        logger.info(f"Fetching sparsevols for leaves of {row.split}")
        split_leaves, split_payload = get_combined_leaf_sparsevol(row.split)
        size, coord_zyx = extract_rle_size_and_first_coord(split_payload)

        logger.info(
            f"Posting mutation {row.mutid}: {size}-voxel split of {row.old} into {row.split} and {row.remain}, from sparsevols of {split_leaves}"
        )

        # Check the destination -- is it the supervoxel we expected to split?
        dest_sv = fetch_label_for_coordinate(*dest_seg,
                                             coord_zyx,
                                             supervoxels=True)
        if dest_sv != row.old:
            raise RuntimeError(
                f"Unexpected supervoxel at the destination: Expected {row.old}, found {dest_sv}"
            )

        post_split_supervoxel(*dest_seg,
                              row.old,
                              split_payload,
                              split_id=row.split,
                              remain_id=row.remain)
        time.sleep(pause_between_splits)

    logger.info("DONE.")
def write_point_neighborhoods(seg_src,
                              seg_dst,
                              points_zyx,
                              radius=125,
                              src_bodies=None,
                              dst_bodies=None):
    """
    For each point in the given list, create a mask of the portion
    of a particular body that falls within a given distance of the
    point.

    Args:
        seg_src:
            tuple (server, uuid, instance) specifying where to fetch neuron segmentation from.
        seg_dst:
            tuple (server, uuid, instance) specifying where to write neighborhood segmentation to.
        points_zyx:
            Array of coordinates to create neighborhoods around
        radius:
            Radius (in voxels) of the neighborhood to create.
            Hint: In the hemibrain, 1 micron = 125 voxels.
        src_bodies:
            Either a single body ID or a list of body IDs (corresponding to the list of points_zyx).
            Specifies which body the neighborhood around each point should be constructed for.
            If not provided, the body for each neighborhood will be chosen automatically,
            by determining which body each point in points_zyx falls within.
        dst_bodies:
            List of new body IDs.
            Specifies the IDs to use as the 'body ID' for the neighborhood segments when writing to
            the destination instance.  If no list is given, then new body IDs are automatically
            generated with a formula that uses the coordinate around which the neighborhood was created.
            Note that the default formula does not take the source body into account,
            so if there are duplicate points provided in points_zyx, the destination body IDs will
            be duplicated too, unless you supply your own destination body IDs here.

    Returns:
        In addition to writing the neighborhood segments to the seg_dst instance,
        this function returns a dataframe with basic stats about the neighborhoods
        that were written.
    """
    if isinstance(points_zyx, pd.DataFrame):
        points_zyx = points_zyx[[*'zyx']].values
    else:
        points_zyx = np.asarray(points_zyx)

    results = []
    for i, point in enumerate(tqdm_proxy(points_zyx)):
        if isinstance(src_bodies, Iterable):
            src_body = src_bodies[i]
        else:
            src_body = src_bodies

        if isinstance(dst_bodies, Iterable):
            dst_body = dst_bodies[i]
        else:
            dst_body = dst_bodies

        point, centroid, top_point, src_body, dst_body, dst_voxels = \
            process_point(seg_src, seg_dst, point, radius, src_body, dst_body)

        results.append(
            (*point, *centroid, *top_point, src_body, dst_body, dst_voxels))

    cols = ['z', 'y', 'x']
    cols += ['cz', 'cy', 'cx']
    cols += ['tz', 'ty', 'tx']
    cols += ['src_body', 'dst_body', 'dst_voxels']
    return pd.DataFrame(results, columns=cols)
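# Not part of the original example: a minimal usage sketch for
# write_point_neighborhoods() above.  The server, uuid, and instance names and
# the coordinates below are placeholders.
seg_src = ('emdata4:8900', 'abc123', 'segmentation')
seg_dst = ('emdata4:8900', 'abc123', 'neighborhood-masks')
example_points_zyx = [[20000, 18000, 15000],
                      [20512, 18100, 15050]]
# radius=125 voxels is one micron in the hemibrain, per the docstring above.
neighborhood_stats_df = write_point_neighborhoods(seg_src, seg_dst,
                                                  example_points_zyx, radius=125)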
def measure_tbar_mito_distances(seg_src, mito_src, body, search_configs,
                                mito_min_size_s0, mito_scale_offset,
                                tbars=None, npclient=None):
    # NOTE: This example was truncated before its function signature; the def
    # line above is reconstructed from the names used in the body below and
    # may not match the original function exactly.
    # Fetch tbars
    if tbars is None:
        tbars = fetch_synapses(body,
                               SC(type='pre', primary_only=True),
                               client=npclient)
    else:
        tbars = tbars.copy()

    tbars['body'] = body
    tbars['mito-distance'] = np.inf
    tbars['done'] = False
    tbars['mito-x'] = 0
    tbars['mito-y'] = 0
    tbars['mito-z'] = 0

    with tqdm_proxy(total=len(tbars)) as progress:
        for row in tbars.itertuples():
            if row.done:
                continue

            for radius_s0, scale in search_configs:
                num_done = _measure_tbar_mito_distances(
                    seg_src, mito_src, body, tbars, row.Index, radius_s0,
                    scale, mito_min_size_s0, mito_scale_offset)
                progress.update(num_done)
                done = (tbars['done'].loc[row.Index])
                if done:
                    break
                logger.info(
                    "Search failed for primary tbar. Trying next search config!"
                )
def construct_mr_endpoint_df(mr_fragments_df, bois):
    """
    Each merge-review task group contains exactly two endpoint nodes.
    Locate each endpoint pair and return it in a DataFrame
    (with _a/_b columns as if it were an ordinary task dataframe).
    """
    assert isinstance(bois, set)
    mr_fragments_df = mr_fragments_df.copy()

    # Make sure our two BOI endpoints are always in body_a
    #_a_is_small = mr_fragments_df.eval('label_a not in @bois')

    # This is much faster than the above .eval() call if bois is large.
    _a_is_small = [(label_a not in bois)
                   for label_a in mr_fragments_df['label_a']]
    _a_is_small = pd.Series(_a_is_small, index=mr_fragments_df.index)
    swap_df_cols(mr_fragments_df, None, _a_is_small, ('a', 'b'))

    # edge_area ends in 'a', which is inconvenient
    # for the column selection below,
    # and we don't need it anyway.  Drop it.
    fmr_df = mr_fragments_df.drop(columns=['edge_area'])

    # All columns ending in 'a'.
    cols_a = [col for col in fmr_df.columns if col.endswith('a')]

    num_tasks = len(mr_fragments_df.drop_duplicates(['group_cc', 'cc_task']))
    post_col = fmr_df[cols_a].columns.tolist().index('PostSyn_a')

    filtered_mr_endpoints = []
    for (group_cc, cc_task), task_df in tqdm_proxy(fmr_df.groupby(
        ['group_cc', 'cc_task']),
                                                   total=num_tasks):
        #assert task_df.eval('label_b not in @_bois').all()

        # Find the two rows that mention a BOI
        #selected_df = task_df.query('label_a in @_bois')

        # Apparently this is MUCH faster than .query() when bois is large
        _a_is_big = [(label_a in bois) for label_a in task_df['label_a']]
        selected_df = task_df.iloc[_a_is_big]

        selected_df = selected_df[cols_a]
        assert len(selected_df) == 2

        stats_a, stats_b = list(selected_df.itertuples(index=False))

        # Put the big body in the 'a' position.
        if stats_a[post_col] < stats_b[post_col]:
            stats_a, stats_b = stats_b, stats_a

        filtered_mr_endpoints.append(
            (group_cc, cc_task, len(task_df), *stats_a, *stats_b))

    cols_b = [col[:-1] + 'b' for col in cols_a]
    combined_cols = ['group_cc', 'cc_task', 'num_edges', *cols_a, *cols_b]
    mr_endpoints_df = pd.DataFrame(filtered_mr_endpoints,
                                   columns=combined_cols)

    final_cols = [
        'group_cc', 'cc_task', 'num_edges', *sorted(combined_cols[3:])
    ]
    mr_endpoints_df = mr_endpoints_df[final_cols]

    return mr_endpoints_df
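# Not part of the original example: a minimal usage sketch for
# construct_mr_endpoint_df() above.  'mr_fragments_df' is assumed to be a
# merge-review fragment table containing label_a/label_b, PostSyn_a/PostSyn_b,
# group_cc, cc_task, and edge_area columns; 'boi_bodies' is any iterable of
# bodies of interest.
def endpoints_sketch(mr_fragments_df, boi_bodies):
    bois = set(boi_bodies)
    return construct_mr_endpoint_df(mr_fragments_df, bois)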
def generate_mergereview_assignments_from_df(server,
                                             uuid,
                                             instance,
                                             mr_fragments_df,
                                             bois,
                                             assignment_size,
                                             output_dir,
                                             single_file=False):
    """
    Generate a set of assignments for the given mergereview fragments.
    The assignments are written to a nested hierarchy:
    Grouped first by task size (number of bodies in each task),
    and then grouped in batches of N tasks (assignment_size).
    
    The body IDs emitted in the assignments and their classification as "BOI"
    or not is determined by fetching the mappings for each supervoxel in the dataframe.
    """
    # Sort table by task size (edge count)
    group_sizes = mr_fragments_df.groupby(['group_cc', 'cc_task'
                                           ]).size().rename('group_size')
    mr_fragments_df = mr_fragments_df.merge(group_sizes,
                                            'left',
                                            left_on=['group_cc', 'cc_task'],
                                            right_index=True)
    mr_fragments_df = mr_fragments_df.sort_values(
        ['group_size', 'group_cc', 'cc_task'])

    mr_fragments_df['body_a'] = fetch_mapping(server, uuid, instance,
                                              mr_fragments_df['sv_a'])
    mr_fragments_df['body_b'] = fetch_mapping(server, uuid, instance,
                                              mr_fragments_df['sv_b'])

    mr_fragments_df['is_boi_a'] = mr_fragments_df.eval('body_a in @bois')
    mr_fragments_df['is_boi_b'] = mr_fragments_df.eval('body_b in @bois')

    # Group assignments by task size and emit an assignment for each group
    all_tasks = {}
    for group_size, same_size_tasks_df in mr_fragments_df.groupby(
            'group_size'):
        group_tasks = []
        for (group_cc, cc_task), task_df in same_size_tasks_df.groupby(
            ['group_cc', 'cc_task']):
            svs = pd.unique(task_df[['sv_a', 'sv_b']].values.reshape(-1))
            svs = np.sort(svs)

            boi_svs = set(task_df[task_df['is_boi_a']]['sv_a'].tolist())
            boi_svs |= set(task_df[task_df['is_boi_b']]['sv_b'].tolist())

            task_bodies = pd.unique(task_df[['body_a', 'body_b'
                                             ]].values.reshape(-1)).tolist()

            task = {
                # neu3 fields
                'task type': "merge review",
                'task id': hex(zlib.crc32(svs)),
                'supervoxel IDs': svs.tolist(),
                'boi supervoxel IDs': sorted(boi_svs),

                # Encode edge table as json
                "supervoxel IDs A": task_df['sv_a'].tolist(),
                "supervoxel IDs B": task_df['sv_b'].tolist(),
                "supervoxel points A": task_df[['xa', 'ya',
                                                'za']].values.tolist(),
                "supervoxel points B": task_df[['xb', 'yb',
                                                'zb']].values.tolist(),

                # Debugging fields
                'group_cc': int(group_cc),
                'cc_task': int(cc_task),
                'original_bodies': sorted(task_bodies),
                'total_body_count': len(task_bodies),
                'original_uuid': uuid,
            }
            group_tasks.append(task)

        num_bodies = group_size + 1
        all_tasks[num_bodies] = group_tasks

    if single_file:
        # In single-file mode, the 'output_dir' is interpreted as the assignment path
        assert output_dir.endswith('.json')
        output_path = output_dir
        assignment = {
            "file type": "Neu3 task list",
            "file version": 1,
            "task list": list(chain(*all_tasks.values()))
        }
        with open(output_path, 'w') as f:
            #json.dump(assignment, f, indent=2)
            pretty_print_assignment_json_items(assignment.items(), f)
    else:
        # Now that the task json data has been generated and split into groups (by body count),
        # write them into multiple directories (one per group), each of which has multiple files
        # (one per task batch, as specified by assignment_size)
        for num_bodies, group_tasks in all_tasks.items():
            output_subdir = f'{output_dir}/{num_bodies:02}-bodies'
            os.makedirs(output_subdir, exist_ok=True)
            for i, batch_start in enumerate(
                    tqdm_proxy(range(0, len(group_tasks), assignment_size),
                               leave=False)):
                output_path = f"{output_dir}/{num_bodies:02}-bodies/assignment-{i:04d}.json"

                batch_tasks = group_tasks[batch_start:batch_start +
                                          assignment_size]
                assignment = {
                    "file type": "Neu3 task list",
                    "file version": 1,
                    "task list": batch_tasks
                }

                with open(output_path, 'w') as f:
                    #json.dump(assignment, f, indent=2)
                    pretty_print_assignment_json_items(assignment.items(), f)

    return all_tasks
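# Not part of the original example: a minimal usage sketch for
# generate_mergereview_assignments_from_df() above.  The server, uuid,
# instance, and directory names are placeholders; 'mr_fragments_df' and 'bois'
# are as described in the docstring above.
def generate_assignments_sketch(mr_fragments_df, bois):
    return generate_mergereview_assignments_from_df(
        'emdata4:8900', 'abc123', 'segmentation',
        mr_fragments_df, bois,
        assignment_size=10,
        output_dir='mr-assignments')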