Example #1
def main():
    # Create the destination instance if necessary.
    dst_instances = fetch_repo_instances(*dst_node, 'annotation')
    if dst_syn not in dst_instances:
        logger.info(f"Creating instance '{dst_syn}'")
        create_instance(*dst_node, dst_syn, 'annotation')

    # Check to see if the sync already exists; add it if necessary
    syn_info = fetch_instance_info(*dst_node, dst_syn)
    if len(syn_info["Base"]["Syncs"]) == 0:
        logger.info(f"Adding a sync to '{dst_syn}' from '{dst_seg}'")
        post_sync(*dst_node, dst_syn, [dst_seg])
    elif syn_info["Base"]["Syncs"][0] != dst_seg:
        other_seg = syn_info["Base"]["Syncs"][0]
        raise RuntimeError(
            f"Can't create a sync to '{dst_seg}'. "
            f"Your instance is already sync'd to a different segmentation: {other_seg}"
        )

    # Fetch segmentation extents
    bounding_box_zyx = fetch_volume_box(*src_node, src_seg).tolist()

    # Break into block-aligned chunks (boxes) that are long in the X direction
    # (optimal access pattern for dvid read/write)
    boxes = boxes_from_grid(bounding_box_zyx, (256, 256, 6400), clipped=True)

    # Use a process pool to copy the chunks in parallel.
    compute_parallel(copy_syn_blocks,
                     boxes,
                     processes=PROCESSES,
                     ordered=False)
Example #2
def main():
    # Hard-coded parameters
    prod = 'emdata4:8900'
    master = (prod, find_master(prod))
    master_seg = (*master, 'segmentation')

    # I accidentally corrupted the labelindex of bodies in this region
    patch_box = 20480 + np.array([[0, 0, 0], [1024, 1024, 1024]])

    with Timer("Fetching supervoxels", logger):
        boxes = boxes_from_grid(patch_box, Grid((64, 64, 6400)), clipped=True)
        sv_sets = compute_parallel(partial(_fetch_svs, master_seg),
                                   boxes,
                                   processes=32,
                                   ordered=False,
                                   leave_progress=True)
        svs = set(chain(*sv_sets)) - set([0])

    bodies = set(fetch_mapping(*master_seg, svs))

    with Timer(f"Repairing {len(bodies)} labelindexes", logger):
        compute_parallel(partial(_repair_index, master_seg),
                         bodies,
                         processes=32,
                         ordered=False,
                         leave_progress=True)

    print("DONE.")
Example #3
def test_compute_parallel():
    items = list(range(100))
    results = compute_parallel(_double, items, threads=2)
    assert results == list(range(0, 200, 2))

    items = list(range(100))
    results = compute_parallel(_double, items, processes=2)
    assert results == list(range(0, 200, 2))

    items = [*zip(range(10), range(100, 110))]
    results = compute_parallel(_add, items, processes=2, starmap=True)
    assert results == [sum(item) for item in items]
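The helpers referenced by this test are not shown in the excerpt. A minimal sketch consistent with the assertions above (only the names `_double` and `_add` come from the test; their bodies are assumptions):

def _double(x):
    # Defined at module level so it can be pickled when compute_parallel uses processes.
    return 2 * x

def _add(a, b):
    # Called with starmap=True, so each (a, b) pair is unpacked into two arguments.
    return a + b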
Example #4
        def process_body(body_id):
            with resource_mgr_client.access_context( input_config["server"], True, 1, 0 ):
                tar_bytes = fetch_tarfile(server, uuid, tsv_instance, body_id)

            sv_meshes = Mesh.from_tarfile(tar_bytes, concatenate=False)
            sv_meshes = {int(os.path.splitext(name)[0]): m for name, m in sv_meshes.items()}

            total_body_vertices = sum([len(m.vertices_zyx) for m in sv_meshes.values()])
            decimation = min(1.0, max_body_vertices / total_body_vertices)

            try:
                _process_sv = partial(process_sv, decimation, decimation_lib, max_sv_vertices, output_format)
                if num_procs <= 1:
                    output_table = [*starmap(_process_sv, sv_meshes.items())]
                else:
                    output_table = compute_parallel(_process_sv, sv_meshes.items(), starmap=True, processes=num_procs, ordered=False, show_progress=False)

                cols = ['sv', 'orig_vertices', 'final_vertices', 'final_decimation', 'effective_decimation', 'mesh_bytes']
                output_df = pd.DataFrame(output_table, columns=cols)
                output_df['body'] = body_id
                output_df['error'] = ""
                write_sv_meshes(output_df, output_config, output_format, resource_mgr_client)
            except Exception as ex:
                svs = [*sv_meshes.keys()]
                orig_vertices = [len(m.vertices_zyx) for m in sv_meshes.values()]
                output_df = pd.DataFrame({'sv': svs, 'orig_vertices': orig_vertices})
                output_df['final_vertices'] = -1
                output_df['final_decimation'] = -1
                output_df['effective_decimation'] = -1
                output_df['mesh_bytes'] = -1
                output_df['body'] = body_id
                output_df['error'] = str(ex)

            return output_df.drop(columns=['mesh_bytes'])
Example #5
def remove_dead_annotations(server, uuid):
    ann = fetch_body_annotations(server, uuid)

    exists = compute_parallel(partial(_does_exist, server, uuid),
                              ann.index.values,
                              processes=16)
    exists = pd.DataFrame(exists, columns=['body', 'exists'])

    to_del = exists.query('not exists')['body']
    print(f"deleting {len(to_del)} dead annotations: {to_del.tolist()}")
    for body in to_del.values:
        delete_key(server, uuid, 'segmentation_annotations', str(body))
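The `_does_exist` worker is also not shown. A hypothetical sketch, reusing the fetch_sizes/HTTPError pattern from a later example; the real worker may check existence differently:

from requests import HTTPError
from neuclease.dvid import fetch_sizes

def _does_exist(server, uuid, body):
    # Hypothetical: treat a failed size lookup as "the body no longer exists".
    # Returns (body, bool) so the results unpack into the ['body', 'exists'] columns above.
    try:
        fetch_sizes(server, uuid, 'segmentation', [body])
        return body, True
    except HTTPError:
        return body, False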
Example #6
def main():
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('-n', type=int, default=8)
    parser.add_argument('src')
    parser.add_argument('dest')
    args = parser.parse_args()

    args.src = abspath(args.src)
    args.dest = abspath(args.dest)

    if exists(args.dest):
        sys.exit(f"Error: Destination already exists: {args.dest}")

    from neuclease.util import compute_parallel, tqdm_proxy
    from neuclease import configure_default_logging
    configure_default_logging()

    os.chdir(args.src)
    logger.info("Listing source files")
    r = subprocess.run('find . -type f', shell=True, capture_output=True)
    src_paths = r.stdout.decode('utf-8').strip().split('\n')
    dest_paths = [f'{args.dest}/{p}' for p in src_paths]
    dest_dirs = sorted(set([*map(dirname, dest_paths)]))

    logger.info("Initializing directory tree")
    for d in tqdm_proxy(dest_dirs):
        os.makedirs(d, exist_ok=True)

    logger.info(f"Copying {len(src_paths)} files")
    compute_parallel(copyfile, [*zip(src_paths, dest_paths)],
                     10,
                     starmap=True,
                     ordered=False,
                     processes=args.n)

    logger.info("DONE")
Example #7
def copy_synapses(src_loc, dst_loc, processes):
    """
    See caveats in the module docstring above.
    """
    src_loc = Location(*src_loc)
    dst_loc = Location(*dst_loc)

    # Create the destination instance if necessary.
    dst_instances = fetch_repo_instances(*dst_loc[:2], 'annotation')
    if dst_loc.syn_instance not in dst_instances:
        logger.info(f"Creating instance '{dst_loc.syn_instance}'")
        create_instance(*dst_loc[:3], 'annotation')

    # Check to see if the sync already exists; add it if necessary
    syn_info = fetch_instance_info(*dst_loc[:3])
    if len(syn_info["Base"]["Syncs"]) == 0:
        logger.info(
            f"Adding a sync to '{dst_loc.syn_instance}' from '{dst_loc.seg_instance}'"
        )
        post_sync(*dst_loc[:3], [dst_loc.seg_instance])
    elif syn_info["Base"]["Syncs"][0] != dst_loc.seg_instance:
        other_seg = syn_info["Base"]["Syncs"][0]
        raise RuntimeError(
            f"Can't create a sync to '{dst_loc.seg_instance}'. "
            f"Your instance is already sync'd to a different segmentation: {other_seg}"
        )

    # Fetch segmentation extents
    bounding_box_zyx = fetch_volume_box(*src_loc[:2],
                                        src_loc.seg_instance).tolist()

    # Break into block-aligned chunks (boxes) that are long in the X direction
    # (optimal access pattern for dvid read/write)
    boxes = boxes_from_grid(bounding_box_zyx, (256, 256, 6400), clipped=True)

    # Use a process pool to copy the chunks in parallel.
    fn = partial(copy_syn_blocks, src_loc, dst_loc)
    compute_parallel(fn, boxes, processes=processes, ordered=False)
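A hypothetical invocation, just to show the expected argument shape; the UUIDs and instance names below are placeholders, and the Location field order (server, uuid, synapse instance, segmentation instance) is inferred from the indexing above:

src = ('emdata4:8900', 'abc123', 'synapses', 'segmentation')
dst = ('emdata4:8900', 'def456', 'synapses', 'segmentation')
copy_synapses(src, dst, processes=16)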
Example #8
def extract_fragments(edges_df, bois, processes):
    """
    For each connected component group (pre-labeled) in the given DataFrame,
    search for paths that can connect the group's BOIs to each other,
    possibly passing through non-BOI nodes in the group.
    
    Returns:
        dict {group_cc: [fragment, fragment, ...]}
        where each fragment is a tuple of N body IDs which form a path of
        adjacent bodies, with a BOI on each end (first node/last node) of
        the path, and non-BOIs for the intermediate nodes (if any).
    """
    assert isinstance(bois, set)
    assert edges_df.duplicated(['group', 'label_a', 'label_b']).sum() == 0

    def _prepare_group(group_cc, cc_df):
        group_bois = bois & set(cc_df[['label_a', 'label_b']].values.reshape(-1))
        return group_cc, cc_df, group_bois

    with Timer("Extracting fragments from each group", logger):
        num_groups = edges_df['group_cc'].nunique()
        group_and_bois = starmap(_prepare_group, edges_df.groupby('group_cc'))

        cc_and_frags = compute_parallel(extract_fragments_for_cc,
                                        group_and_bois,
                                        1000,
                                        processes=processes,
                                        ordered=False,
                                        leave_progress=True,
                                        total=num_groups,
                                        starmap=True)

    fragments = dict(cc_and_frags)
    num_fragments = sum(len(frags) for frags in fragments.values())
    logger.info(f"Extracted {num_fragments} fragments")
    return fragments
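For reference, a small sketch of how the return value described in the docstring might be consumed (illustrative only; `fragments` is the dict returned above):

# fragments: {group_cc: [fragment, ...]}, each fragment a tuple of body IDs
# with a BOI at each end of the path.
for group_cc, frags in fragments.items():
    for frag in frags:
        print(group_cc, frag[0], frag[-1], len(frag))  # endpoint BOIs and path length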
Example #9
def main():
    RESULTS_PKL_PATH = sys.argv[1]
    if len(sys.argv) == 3:
        PROCESSES = int(sys.argv[2])
    else:
        PROCESSES = 4

    # Calculate the difference in resolution between the stored mito segmentation and neuron segmentation.
    # If they differ, it must be by a power of 2.
    mito_res = fetch_info(*MITO_SEG)["Extended"]["VoxelSize"][0]
    assert mito_res % NEIGHBORHOOD_RES == 0
    assert np.log2(mito_res / NEIGHBORHOOD_RES) == int(np.log2(mito_res / NEIGHBORHOOD_RES)), \
        "This script assumes that the mito resolution and neighborhood resolution differ by a power of 2."
    mito_res_scale_diff = int(np.log2(mito_res // NEIGHBORHOOD_RES))

    with open(RESULTS_PKL_PATH, 'rb') as f:
        mc_df = pickle.load(f)

    new_names = {col: col.replace(' ', '_') for col in mc_df.columns}
    new_names['result'] = 'proofreader_count'
    mc_df = mc_df.rename(columns=new_names)

    print("Evaluating mito count results")
    results = compute_parallel(partial(_task_results, mito_res_scale_diff),
                               iter_batches(
                                   mc_df.drop_duplicates('neighborhood_id'),
                                   1),
                               total=mc_df['neighborhood_id'].nunique(),
                               processes=PROCESSES,
                               leave_progress=True,
                               ordered=False)

    cols = [
        'neighborhood_id', 'neighborhood_origin', 'proofreader_count',
        'mito_id_count', 'mito_ids', 'mito_sizes', 'num_ccs', 'mito_cc_ids',
        'mito_cc_sizes', 'ng_link'
    ]

    df = pd.DataFrame(results, columns=cols)

    # Add columns for cell type (from neuprint)
    print("Fetching neuron cell types")
    origins_df = pd.DataFrame(df['neighborhood_origin'].tolist(),
                              columns=[*'xyz'])
    df['body'] = fetch_labels_batched(*NEURON_SEG,
                                      origins_df[[*'zyx']].values,
                                      processes=8)
    neurons_df, _ = fetch_neurons(df['body'].unique())
    neurons_df = neurons_df.rename(columns={
        'bodyId': 'body',
        'type': 'body_type',
        'instance': 'body_instance'
    })
    df = df.merge(neurons_df[['body', 'body_type', 'body_instance']],
                  'left',
                  on='body')
    df['body_type'] = df['body_type'].fillna("")
    df['body_instance'] = df['body_instance'].fillna("")

    # Append roi column
    print("Determining ROIs")
    determine_point_rois(*NEURON_SEG[:2], NEUPRINT_CLIENT.primary_rois,
                         origins_df)
    df['roi'] = origins_df['roi']

    # Results only
    path = 'mito-seg-counts.pkl'
    print(f"Writing {path}")
    with open(path, 'wb') as f:
        pickle.dump(df, f)

    path = 'mito-seg-counts.tab-delimited.csv'
    print(f"Writing {path}")
    df.to_csv(path, sep='\t', header=True, index=False)

    # Full results (with task info columns)
    df = df.merge(
        mc_df.drop(columns=['neighborhood_origin', 'proofreader_count']),
        'left',
        on='neighborhood_id')

    path = 'full-results-with-mito-seg-counts.pkl'
    print(f"Writing {path}")
    with open(path, 'wb') as f:
        pickle.dump(df, f)

    path = 'full-results-with-mito-seg-counts.tab-delimited.csv'
    print(f"Writing {path}")
    df.to_csv(path, sep='\t', header=True, index=False)

    print("DONE")
Example #10
def correct_centroids(config,
                      stats_df,
                      check_scale=0,
                      verify=False,
                      threads=0,
                      processes=8):
    import numpy as np
    import pandas as pd

    from neuclease.util import compute_parallel, Timer
    from flyemflows.volumes import VolumeService

    with Timer("Pre-sorting points by block", logger):
        stats_df['bz'] = stats_df['by'] = stats_df['bx'] = np.int32(0)
        stats_df[['bz', 'by', 'bx']] = stats_df[[*'zyx']] // 64
        stats_df.sort_values(['bz', 'by', 'bx'], inplace=True)
        stats_df.drop(columns=['bz', 'by', 'bx'], inplace=True)

    if config['mito-sparsevol-source'] is not None:
        sparsevol_source = VolumeService.create_from_config(
            config['mito-sparsevol-source'])
        point_source = sparsevol_source
    else:
        sparsevol_source = None
        point_source = None

    if config['mito-point-source']:
        point_source = VolumeService.create_from_config(
            config['mito-point-source'])

    assert point_source or sparsevol_source, \
        "You must provide either a point-source or sparsevol-source."

    stats_df['centroid_label'] = sample_labels(point_source, stats_df,
                                               check_scale, threads, processes)

    mismatched_mitos = stats_df.query('centroid_label != mito_id')

    logger.info(
        f"Correcting {len(mismatched_mitos)} mismatched mito centroids")

    if sparsevol_source:
        _find_mito = partial(find_mito_from_sparsevol,
                             *sparsevol_source.instance_triple)
        mitos_and_coords = compute_parallel(_find_mito,
                                            mismatched_mitos.index,
                                            ordered=False,
                                            threads=threads,
                                            processes=processes)
    else:
        _find_mito = partial(find_mito_from_seg, point_source, check_scale)
        mismatched_rows = mismatched_mitos.reset_index()[['mito_id', *'zyx']].astype(np.int64).values
        mitos_and_coords = compute_parallel(_find_mito,
                                            mismatched_rows,
                                            starmap=True,
                                            ordered=False,
                                            threads=threads,
                                            processes=processes)

    corrected_df = pd.DataFrame(mitos_and_coords, columns=['mito_id', *'zyx']).set_index('mito_id')
    stats_df.loc[corrected_df.index, [*'zyx']] = corrected_df[[*'zyx']]

    stats_df['centroid_type'] = 'exact'
    stats_df.loc[corrected_df.index, 'centroid_type'] = 'adjusted'

    # Sanity check: they should all be correct now!
    if verify:
        new_labels = sample_labels(point_source,
                                   stats_df.loc[mismatched_mitos.index],
                                   check_scale, threads, processes)
        if (new_labels != mismatched_mitos.index).any():
            logger.error("Some mitos remained mismstached!")

    return stats_df
Example #11
def correct_centroids(config, stats_df, check_scale=0, verify=False, threads=0, processes=8):
    import numpy as np
    import pandas as pd

    from neuclease.util import tqdm_proxy, compute_parallel, Timer
    from neuclease.dvid import fetch_labels_batched
    from flyemflows.volumes import VolumeService, DvidVolumeService

    with Timer("Pre-sorting points by block", logger):
        stats_df['bz'] = stats_df['by'] = stats_df['bx'] = np.int32(0)
        stats_df[['bz', 'by', 'bx']] = stats_df[[*'zyx']] // 64
        stats_df.sort_values(['bz', 'by', 'bx'], inplace=True)
        stats_df.drop(columns=['bz', 'by', 'bx'], inplace=True)

    sparsevol_source = VolumeService.create_from_config(config['mito-sparsevol-source'])
    if config['mito-point-source'] is None:
        point_source = sparsevol_source
    else:
        point_source = VolumeService.create_from_config(config['mito-point-source'])

    if isinstance(point_source, DvidVolumeService):
        stats_df['centroid_label'] = fetch_labels_batched(*point_source.instance_triple,
                                                          stats_df[[*'zyx']] // (2**check_scale),
                                                          supervoxels=point_source.supervoxels,
                                                          scale=check_scale,
                                                          batch_size=1000,
                                                          threads=threads,
                                                          processes=processes)
    else:
        import multiprocessing as mp
        import dask
        from dask.diagnostics import ProgressBar

        if threads:
            pool = mp.pool.ThreadPool(threads)
        else:
            pool = mp.pool.Pool(processes)

        dask.config.set(scheduler='processes')
        with pool, dask.config.set(pool=pool), ProgressBar():
            centroids = stats_df[[*'zyx']] // (2**check_scale)
            stats_df['centroid_label'] = point_source.sample_labels( centroids, scale=check_scale )

    mismatched_mitos = stats_df.query('centroid_label != mito_id').index

    logger.info(f"Correcting {len(mismatched_mitos)} mismatched mito centroids")
    _find_mito = partial(find_mito, *sparsevol_source.instance_triple)
    mitos_and_coords = compute_parallel(_find_mito, mismatched_mitos, ordered=False, threads=threads, processes=processes)
    corrected_df = pd.DataFrame(mitos_and_coords, columns=['mito_id', *'zyx']).set_index('mito_id')
    stats_df.loc[corrected_df.index, [*'zyx']] = corrected_df[[*'zyx']]
    stats_df.loc[corrected_df.index, 'centroid_type'] = 'adjusted'

    # Sanity check: they should all be correct now!
    if verify:
        new_centroids = stats_df.loc[mismatched_mitos, [*'zyx']].values
        new_labels = fetch_labels_batched(*sparsevol_source.instance_triple,
                                          new_centroids,
                                          supervoxels=True,
                                          threads=threads,
                                          processes=processes)

        if (new_labels != mismatched_mitos).any():
            logger.error("Some mitos remained mismstached!")

    return stats_df
Example #12
# These accumulators are assumed to be initialized before the loop;
# the excerpt begins mid-script.
body_groups = []
group_list = []
body_count = 0

for line in bodyList:
    if line[0].isdigit():
        bodyID = line.rstrip('\n')
        group_list.append(int(bodyID))
        body_count += 1
        if body_count == 1000:
            body_groups.append(group_list)
            group_list = []
            body_count = 0

if len(group_list) > 0:
    body_groups.append(group_list)

PROCESSES = 15

def get_sizes(label_ids):
    try:
        sizes_pd = fetch_sizes(*master_seg, label_ids, supervoxels=False)
    except HTTPError:
        s_empty_pd = pd.Series(index=label_ids, data=-1, dtype=int)
        s_empty_pd.name = 'size'
        s_empty_pd.index.name = 'body'
        return s_empty_pd
    else:
        return sizes_pd

body_sizes_df_list = compute_parallel(get_sizes, body_groups, chunksize=100, processes=PROCESSES, ordered=False)

body_sizes_pd = pd.concat(body_sizes_df_list)

body_sizes_pd.to_csv(out_file, index=True)
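Side note: the manual 1000-body batching above could also be written with neuclease.util.iter_batches (used in an earlier example). A sketch, assuming the body IDs have already been parsed into a flat list named body_ids:

from neuclease.util import iter_batches
body_groups = [*iter_batches(body_ids, 1000)]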
Example #13
def fetch_body_edge_table(cleave_server, dvid_server, uuid, instance, body):
    dvid_server, dvid_port = dvid_server.split(':')

    if not cleave_server.startswith('http'):
        cleave_server = 'http://' + cleave_server

    data = {
        "body-id": body,
        "port": dvid_port,
        "server": dvid_server,
        "uuid": uuid,
        "segmentation-instance": instance,
        "user": "******"
    }

    r = requests.post(f'{cleave_server}/body-edge-table', json=data)
    r.raise_for_status()

    df = pd.read_csv(BytesIO(r.content), header=0)
    df = df.astype({'id_a': np.uint64, 'id_b': np.uint64, 'score': np.float32})
    return df


def warm_body(body):
    fetch_body_edge_table(CLEAVE_SERVER, DVID_SERVER, DVID_UUID,
                          'segmentation', body)


bodies = pd.read_csv(BODY_CSV)['body']
_ = compute_parallel(warm_body, bodies, threads=THREADS, ordered=False)
Example #14
def neuron_mito_stats(seg_src, mito_cc_src, mito_class_src, body_id, scale=0, min_size=0, search_radius=50, processes=1):
    from functools import partial
    import numpy as np
    import pandas as pd

    from neuclease.util import compute_parallel
    from neuclease.dvid import fetch_sparsevol_coarse, resolve_ref, fetch_labels, fetch_labelmap_voxels

    seg_src[1] = resolve_ref(*seg_src[:2])
    mito_cc_src[1] = resolve_ref(*mito_cc_src[:2])
    mito_class_src[1] = resolve_ref(*mito_class_src[:2])

    # Fetch block coords; re-scale for the analysis scale
    block_coords = (2**6) * fetch_sparsevol_coarse(*seg_src, body_id)
    bc_df = pd.DataFrame(block_coords, columns=[*'zyx'])
    bc_df[[*'zyx']] //= 2**scale
    block_coords = bc_df.drop_duplicates().values

    #
    # Blockwise stats
    #
    block_fn = partial(_process_block, seg_src, mito_cc_src, mito_class_src, body_id, scale)
    block_tables = compute_parallel(block_fn, block_coords, processes=processes)
    block_tables = [*filter(lambda t: t is not None, block_tables)]
    #
    # Combine stats
    #
    full_table = pd.concat(block_tables, sort=True).fillna(0)
    class_cols = [*filter(lambda c: c.startswith('class'), full_table.columns)]
    full_table = full_table.astype({c: np.int32 for c in class_cols})

    # Weight each block centroid by the block's voxel count before taking the mean
    full_table[[*'zyx']] *= full_table[['total_size']].values
    stats_df = full_table.groupby('mito_id').sum()
    stats_df[[*'zyx']] /= stats_df[['total_size']].values

    # Drop tiny mitos
    stats_df = stats_df.query("total_size >= @min_size").copy()

    # Assume all centroids are 'exact' by default (overwritten below if necessary)
    stats_df['centroid_type'] = 'exact'

    # Include a column for 'body' even though it's the same on every row,
    # just as a convenience for concatenating these results with the results
    # from other bodies if desired.
    stats_df['body'] = body_id

    stats_df = stats_df.astype({a: np.int32 for a in 'zyx'})
    stats_df = stats_df[['body', *'xyz', 'total_size', *class_cols, 'centroid_type']]

    #
    # Check for centroids that fall outside of the mito,
    # and adjust them if necessary.
    #
    centroid_mitos = fetch_labels(*mito_cc_src, stats_df[[*'zyx']].values, scale=scale)
    mismatches = stats_df.index[(stats_df.index != centroid_mitos)]

    if len(mismatches) == 0:
        return stats_df

    logger.warning("Some mitochondria centroids do not lie within the mitochondria itself. "
                   "Searching for pseudo-centroids.")

    # construct field of distances from the central voxel
    sr = search_radius
    cz, cy, cx = np.ogrid[-sr:sr+1, -sr:sr+1, -sr:sr+1]
    distances = np.sqrt(cz**2 + cy**2 + cx**2)

    pseudo_centroids = []
    error_mito_ids = []
    for row in stats_df.loc[mismatches].itertuples():
        mito_id = row.Index
        centroid = np.array((row.z, row.y, row.x))
        box = (centroid - sr, 1 + centroid + sr)
        mito_mask = (mito_id == fetch_labelmap_voxels(*mito_cc_src, box, scale))

        if not mito_mask.any():
            pseudo_centroids.append((row.z, row.y, row.x))
            error_mito_ids.append(mito_id)
            continue

        # Find minimum distance
        masked_distances = np.where(mito_mask, distances, np.inf)
        new_centroid = np.unravel_index(np.argmin(masked_distances), masked_distances.shape)
        new_centroid = np.array(new_centroid) + centroid - sr
        pseudo_centroids.append(new_centroid)

    stats_df.loc[mismatches, ['z', 'y', 'x']] = np.array(pseudo_centroids, dtype=np.int32)
    stats_df.loc[mismatches, 'centroid_type'] = 'adjusted'
    stats_df.loc[error_mito_ids, 'centroid_type'] = 'error'

    if error_mito_ids:
        logger.warning("Some mitochondria pseudo-centroids could not be found.")

    stats_df = stats_df.astype({a: np.int32 for a in 'zyx'})
    return stats_df