Example #1
def test_post_mappings(labelmap_setup):
    """
    Test the wrapper function for the /mappings DVID API.
    """
    dvid_server, dvid_repo, _merge_table_path, _mapping_path, _supervoxel_vol = labelmap_setup
    instance_info = DvidInstanceInfo(dvid_server, dvid_repo, 'segmentation')

    # Fetch the original mapping
    orig_mapping = fetch_mappings(*instance_info).sort_index()
    assert (orig_mapping.index == [2, 3, 4, 5]).all()
    assert (orig_mapping == 1).all()  # see initialization in conftest.py

    # Make sure post_mappings does not REQUIRE the Series to be named
    orig_mapping.index.name = 'barfoo'
    orig_mapping.name = 'foobar'

    # Now post a new mapping and read it back.
    new_mapping = orig_mapping.copy()
    new_mapping[:] = 2
    new_mapping.sort_index(inplace=True)

    # Post all but sv 5
    post_mappings(*instance_info, new_mapping.iloc[:-1], mutid=1)
    fetched_mapping = fetch_mappings(*instance_info).sort_index()
    assert (fetched_mapping.index == [3, 4, 5]).all()
    assert (fetched_mapping.iloc[:-1] == 2).all()
    assert (fetched_mapping.iloc[-1:] == 1).all()

    # Now post sv 5, too
    post_mappings(*instance_info, new_mapping.iloc[-1:], mutid=1)
    fetched_mapping = fetch_mappings(*instance_info).sort_index()
    assert (fetched_mapping.index == [3, 4, 5]).all()
    assert (fetched_mapping == 2).all()

    # Try batched
    new_mapping = pd.Series(index=[1, 2, 3, 4, 5], data=[1, 1, 1, 2, 2])
    post_mappings(*instance_info, new_mapping, mutid=1, batch_size=4)
    fetched_mapping = fetch_mappings(*instance_info).sort_index()
    assert (fetched_mapping.index == [2, 3, 4, 5]).all()
    assert (fetched_mapping == [1, 1, 2, 2]).all(), f"{fetched_mapping} != {[1, 1, 2, 2]}"

    # Restore the original mapping
    post_mappings(*instance_info, orig_mapping, 1)
    fetched_mapping = fetch_mappings(*instance_info).sort_index()
    assert (fetched_mapping.index == [2, 3, 4, 5]).all()
    assert (fetched_mapping == 1).all()
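The test above exercises the expected input shape: post_mappings takes a pandas Series whose index holds supervoxel IDs and whose values hold the target body IDs; as the test verifies, the index and value names are optional. A minimal sketch of constructing such a Series (toy IDs; the call itself is hypothetical and reuses the fixture's instance_info):

import pandas as pd

# Index = supervoxel IDs, values = body IDs; names are optional.
new_mapping = pd.Series(index=[2, 3, 4, 5], data=[7, 7, 8, 8])

# Hypothetical call, mirroring the test above:
# post_mappings(*instance_info, new_mapping, mutid=1)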
Example #2
def test_fetch_mappings(labelmap_setup):
    """
    Test the wrapper function for the /mappings DVID API.
    """
    dvid_server, dvid_repo, _merge_table_path, _mapping_path, _supervoxel_vol = labelmap_setup
    instance_info = DvidInstanceInfo(dvid_server, dvid_repo, 'segmentation')

    mapping = fetch_mappings(*instance_info)
    assert isinstance(mapping, pd.Series)
    assert mapping.index.name == 'sv'
    assert mapping.name == 'body'
    # Does not include the 'identity' row for SV 1. See docstring.
    assert sorted(mapping.index) == [2, 3, 4, 5]
    assert (mapping == 1).all()  # see initialization in conftest.py
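Since fetch_mappings returns an ordinary pandas Series (index named 'sv', values named 'body'), standard Series operations apply directly. A minimal sketch with toy values standing in for a fetched mapping:

import pandas as pd

# Toy stand-in for the Series returned by fetch_mappings()
mapping = pd.Series(index=pd.Index([2, 3, 4, 5], name='sv'),
                    data=[1, 1, 1, 1], name='body')

mapping.loc[[3, 5]]                                       # bodies for specific supervoxels
mapping.reset_index().groupby('body')['sv'].apply(list)   # invert: supervoxels per body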
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--split-into-batches', type=int,
                        help='If given, also split the body stats into this many batches of roughly equal size')
    parser.add_argument('server')
    parser.add_argument('src_uuid')
    parser.add_argument('labelmap_instance')
    parser.add_argument('supervoxel_block_stats_h5',
                        help=f'An HDF5 file with a single dataset "stats", with dtype: {STATS_DTYPE[1:]} (Note: No column for body_id)')
    args = parser.parse_args()

    configure_default_logging()
    initialize_excepthook()
    (block_sv_stats, _presorted_by, _agglo_path) = load_stats_h5_to_records(args.supervoxel_block_stats_h5)

    src_info = (args.server, args.src_uuid, args.labelmap_instance)
    mapping = fetch_mappings(*src_info)

    assert isinstance(mapping, pd.Series)
    mapping_df = mapping.reset_index().rename(columns={'sv': 'segment_id', 'body': 'body_id'})

    # sorts in-place, and saves a copy to hdf5
    sort_block_stats(block_sv_stats,
                     mapping_df,
                     args.supervoxel_block_stats_h5[:-3] + '-sorted-by-body.h5',
                     '<fetched-from-dvid>')
    
    if args.split_into_batches:
        num_batches = args.split_into_batches
        batch_size = int(np.ceil(len(block_sv_stats) / args.split_into_batches))
        logger.info(f"Splitting into {args.split_into_batches} batches of size ~{batch_size}")
        os.makedirs('stats-batches', exist_ok=True)
        
        body_spans = groupby_spans_presorted(block_sv_stats['body_id'][:, None])
        for batch_index, batch_spans in enumerate(tqdm_proxy(iter_batches(body_spans, batch_size))):
            span_start, span_stop = batch_spans[0][0], batch_spans[-1][1]
            batch_stats = block_sv_stats[span_start:span_stop]
            digits = int(np.ceil(np.log10(num_batches)))
            batch_path = f'stats-batches/stats-batch-{batch_index:0{digits}d}.h5'
            save_stats(batch_stats, batch_path)
    
    logger.info("DONE sorting stats by body")
Example #4
    def sparse_brick_coords_for_labels(self, labels, clip=True):
        """
        Return a DataFrame indicating the brick
        coordinates (starting corner) that encompass the given labels.

        Args:
            labels:
                A list of body IDs (if ``self.supervoxels`` is False),
                or supervoxel IDs (if ``self.supervoxels`` is True).

            clip:
                If True, filter the results to exclude any coordinates
                that fall outside this service's bounding-box.
                Otherwise, all brick coordinates that encompass the given labels
                will be returned, whether or not they fall within the bounding box.

        Returns:
            DataFrame with columns [z,y,x,label],
            where z,y,x represents the starting corner (in full-res coordinates)
            of a brick that contains the label.
        """
        assert not isinstance(labels, set), "Pass labels as a list or array, not a set"
        labels = pd.unique(labels)
        is_supervoxels = self.supervoxels
        brick_shape = self.preferred_message_shape
        assert (brick_shape % self.block_width == 0).all(), \
            ("Brick shape ('preferred-message-shape') must be a multiple of the "
             f"block width ({self.block_width}) in all dimensions, not {brick_shape}")

        bad_labels = []

        if not is_supervoxels:
            # No supervoxel filtering.
            # Sort by body, since that should be slightly nicer for dvid performance.
            bodies_and_svs = {label: None for label in sorted(labels)}
        else:
            # Arbitrary heuristic for whether to do the body-lookups on DVID or on the client.
            if len(labels) < 100_000:
                # If we're only dealing with a few supervoxels,
                # ask dvid to map them to bodies for us.
                mapping = fetch_mapping(*self.instance_triple,
                                        labels,
                                        as_series=True)
            else:
                # If we're dealing with a lot of supervoxels, ask for
                # the entire mapping, and look up the bodies ourselves.
                complete_mapping = fetch_mappings(*self.instance_triple)
                mapper = LabelMapper(complete_mapping.index.values,
                                     complete_mapping.values)

                labels = np.asarray(labels, np.uint64)
                bodies = mapper.apply(labels, True)
                mapping = pd.Series(index=labels, data=bodies, name='body')
                mapping.index.rename('sv', inplace=True)

            bad_svs = mapping[mapping == 0]
            bad_labels.extend(bad_svs.index.tolist())

            # Group by body
            mapping = mapping[mapping != 0]
            grouped_svs = mapping.reset_index().groupby('body').agg({'sv': list})['sv']

            # Sort by body, since that should be slightly nicer for dvid performance.
            bodies_and_svs = grouped_svs.sort_index().to_dict()

        # Extract these to avoid pickling 'self' (just for speed)
        server, uuid, instance = self.instance_triple
        if self._use_resource_manager_for_sparse_coords:
            mgr = self.resource_manager_client
        else:
            mgr = ResourceManagerClient("", 0)

        def fetch_brick_coords(body, supervoxel_subset):
            """
            Fetch the block coordinates for the given body,
            filter them for the given supervoxels (if any),
            and convert the block coordinates to brick coordinates.
            """
            assert is_supervoxels or supervoxel_subset is None

            try:
                with mgr.access_context(server, True, 1, 1):
                    labelindex = fetch_labelindex(server, uuid, instance, body,
                                                  'protobuf')
                coords_df = convert_labelindex_to_pandas(labelindex).blocks

            except HTTPError as ex:
                if (ex.response is not None
                        and ex.response.status_code == 404):
                    return (body, None)
                raise
            except RuntimeError as ex:
                if 'does not map to any body' in str(ex):
                    return (body, None)
                raise

            if len(coords_df) == 0:
                return (body, None)

            if is_supervoxels:
                supervoxel_subset = set(supervoxel_subset)
                coords_df = coords_df.query('sv in @supervoxel_subset').copy()

            coords_df[['z', 'y', 'x']] //= brick_shape
            coords_df['body'] = np.uint64(body)
            coords_df.drop_duplicates(inplace=True)
            return (body, coords_df)

        def fetch_and_concatenate_brick_coords(bodies_and_supervoxels):
            """
            To reduce the number of tiny DataFrames collected to the driver,
            it's best to concatenate the partitions first, on the workers,
            rather than a straightforward call to starmap(fetch_brick_coords).

            Hence, this function that consolidates each partition.
            """
            bad_bodies = []
            coord_dfs = []
            for (body, supervoxel_subset) in bodies_and_supervoxels:
                _, coords_df = fetch_brick_coords(body, supervoxel_subset)
                if coords_df is None:
                    bad_bodies.append(body)
                else:
                    coord_dfs.append(coords_df)
                    del coords_df

            if coord_dfs:
                return [(pd.concat(coord_dfs, ignore_index=True), bad_bodies)]
            else:
                return [(None, bad_bodies)]

        msg = f"Fetching coarse sparsevols for {len(labels)} labels ({len(bodies_and_svs)} bodies)"
        with Timer(msg, logger=logger):
            import dask.bag as db
            # Instead of fancy heuristics, just pick 4096 partitions.
            bag = db.from_sequence(bodies_and_svs.items(), npartitions=4096)
            coords_and_bad_bodies = bag.map_partitions(fetch_and_concatenate_brick_coords).compute()

        coords_df_partitions, bad_body_partitions = zip(*coords_and_bad_bodies)

        for body in chain(*bad_body_partitions):
            if is_supervoxels:
                bad_labels.extend(bodies_and_svs[body])
            else:
                bad_labels.append(body)

        if bad_labels:
            name = 'sv' if is_supervoxels else 'body'
            bad_series = pd.Series(bad_labels, name=name)
            bad_series.to_csv('labels-without-sparsevols.csv', index=False, header=True)
            if len(bad_labels) < 100:
                msg = f"Could not obtain coarse sparsevol for {len(bad_labels)} labels: {bad_labels}"
            else:
                msg = f"Could not obtain coarse sparsevol for {len(bad_labels)} labels. See labels-without-sparsevols.csv"

            logger.warning(msg)

        coords_df_partitions = [df for df in coords_df_partitions if df is not None]
        if len(coords_df_partitions) == 0:
            raise RuntimeError(
                "Could not find bricks for any of the given labels")

        coords_df = pd.concat(coords_df_partitions, ignore_index=True)

        if self.supervoxels:
            coords_df['label'] = coords_df['sv']
        else:
            coords_df['label'] = coords_df['body']

        coords_df.drop_duplicates(['z', 'y', 'x', 'label'], inplace=True)
        coords_df[['z', 'y', 'x']] *= brick_shape

        if clip:
            # Keep if the last pixel in the brick is to the right of the bounding-box start
            # and the first pixel in the brick is to the left of the bounding-box stop
            keep = (coords_df[['z', 'y', 'x']] + brick_shape >
                    self.bounding_box_zyx[0]).all(axis=1)
            keep &= (coords_df[['z', 'y', 'x']] <
                     self.bounding_box_zyx[1]).all(axis=1)
            coords_df = coords_df.loc[keep]

        return coords_df[['z', 'y', 'x', 'label']]
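The client-side lookup path above uses LabelMapper from the dvidutils package: it is constructed from parallel key/value arrays and vectorizes the lookup over large uint64 arrays. A minimal sketch with toy IDs (the pass-through semantics of the second argument are as assumed by the code above):

import numpy as np
from dvidutils import LabelMapper

mapper = LabelMapper(np.array([2, 3, 4, 5], np.uint64),  # supervoxel IDs
                     np.array([1, 1, 9, 9], np.uint64))  # corresponding body IDs

svs = np.array([3, 5, 7], np.uint64)
bodies = mapper.apply(svs, True)  # True: pass unmapped IDs (here 7) through unchanged
# bodies -> [1, 9, 7]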
Example #5
                            *,
                            threads=0,
                            processes=0,
                            last_mutid=None,
                            mapping=None):
    assert not (threads and processes), \
        "Use threads or processes (or neither), but not both."
    if last_mutid is None:
        last_mutid = fetch_repo_info(*src_info[:2])["MutationID"]

    (block_sv_stats, presorted_by,
     _agglo_path) = load_stats_h5_to_records(erased_block_stats_h5)

    if presorted_by != 'body_id':
        if mapping is None:
            mapping = fetch_mappings(*src_info)

        assert isinstance(mapping, pd.Series)
        mapping_df = mapping.reset_index().rename(columns={
            'sv': 'segment_id',
            'body': 'body_id'
        })

        # sorts in-place, and saves a copy to hdf5
        sort_block_stats(block_sv_stats, mapping_df,
                         erased_block_stats_h5[:-3] + '-sorted-by-body.h5',
                         '<fetched-from-dvid>')

    if threads > 0:
        pool = ThreadPool(threads)
    elif processes > 0:
        pool = Pool(processes)  # assumed continuation; the source snippet is truncated here
Example #6
def fetch_roi_synapses(server,
                       uuid,
                       synapses_instance,
                       rois,
                       fetch_labels=False,
                       return_partners=False,
                       processes=16):
    """
    Fetch the coordinates and (optionally) body labels for 
    all synapses that fall within the given ROIs.
    
    Args:
    
        server:
            DVID server, e.g. 'emdata4:8900'
        
        uuid:
            DVID uuid, e.g. 'abc9'
        
        synapses_instance:
            DVID synapses instance name, e.g. 'synapses'
        
        rois:
            A single DVID ROI instance name, or a list of them, e.g. 'EB' or ['EB', 'FB']
        
        fetch_labels:
            If True, also fetch the supervoxel and body label underneath each synapse,
            returned in columns 'sv' and 'body'.
            
        return_partners:
            If True, also return the partners table.

        processes:
            How many parallel processes to use when fetching synapses and supervoxel labels.
    
    Returns:
        pandas DataFrame with columns
        ``['z', 'y', 'x', 'kind', 'conf', 'roi_label', 'roi']``,
        plus ``['sv', 'body']`` if ``fetch_labels=True``.
        If return_partners is True, the partners table is also returned.

    Example:
        df = fetch_roi_synapses('emdata4:8900', '3c281', 'synapses', ['PB(L5)', 'PB(L7)'], True, processes=8)
    """
    # Late imports to avoid circular imports in dvid/__init__
    from neuclease.dvid import fetch_combined_roi_volume, determine_point_rois, fetch_labels_batched, fetch_mapping, fetch_mappings

    assert rois, "No rois provided, result would be empty. Is that what you meant?"

    if isinstance(rois, str):
        rois = [rois]

    # Determine name of the segmentation instance that's
    # associated with the given synapses instance.
    syn_info = fetch_instance_info(server, uuid, synapses_instance)
    seg_instance = syn_info["Base"]["Syncs"][0]

    logger.info(f"Fetching mask for ROIs: {rois}")
    # Fetch the ROI as a low-res array (scale 5, i.e. 32-px resolution)
    roi_vol_s5, roi_box_s5, overlapping_pairs = fetch_combined_roi_volume(
        server, uuid, rois)

    if len(overlapping_pairs) > 0:
        logger.warning(
            "Some ROIs overlapped and are thus not completely represented in the output:\n"
            f"{overlapping_pairs}")

    # Convert to full-res box
    roi_box = (2**5) * roi_box_s5

    # fetch_synapses_in_batches() requires a box that is 64-px-aligned
    roi_box = round_box(roi_box, 64, 'out')

    logger.info("Fetching synapse points")
    # points_df is a DataFrame with columns for [z,y,x]
    points_df, partners_df = fetch_synapses_in_batches(server,
                                                       uuid,
                                                       synapses_instance,
                                                       roi_box,
                                                       processes=processes)

    # Append a 'roi_name' column to points_df
    logger.info("Labeling ROI for each point")
    determine_point_rois(server, uuid, rois, points_df, roi_vol_s5, roi_box_s5)

    logger.info("Discarding points that don't overlap with the roi")
    rois = {*rois}
    points_df = points_df.query('roi in @rois').copy()

    columns = ['z', 'y', 'x', 'kind', 'conf', 'roi_label', 'roi']

    if fetch_labels:
        logger.info("Fetching supervoxel under each point")
        svs = fetch_labels_batched(server,
                                   uuid,
                                   seg_instance,
                                   points_df[['z', 'y', 'x']].values,
                                   supervoxels=True,
                                   processes=processes)

        with Timer("Mapping supervoxels to bodies", logger):
            # Arbitrary heuristic for whether to do the
            # body-lookups on DVID or on the client.
            if len(svs) < 100_000:
                bodies = fetch_mapping(server, uuid, seg_instance, svs)
            else:
                mapping = fetch_mappings(server, uuid, seg_instance)
                mapper = LabelMapper(mapping.index.values, mapping.values)
                bodies = mapper.apply(svs, True)

        points_df['sv'] = svs
        points_df['body'] = bodies
        columns += ['body', 'sv']

    if return_partners:
        # Filter
        #partners_df = partners_df.query('post_id in @points_df.index and pre_id in @points_df.index').copy()

        # Faster filter (via merge)
        partners_df = partners_df.merge(points_df[[]],
                                        'inner',
                                        left_on='pre_id',
                                        right_index=True)
        partners_df = partners_df.merge(points_df[[]],
                                        'inner',
                                        left_on='post_id',
                                        right_index=True)
        return points_df[columns], partners_df
    else:
        return points_df[columns]
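The "faster filter (via merge)" trick at the end is worth noting on its own: inner-merging against a zero-column slice of points_df keeps exactly the rows whose key appears in points_df's index, which is typically faster than an 'in'-based query for large tables. A minimal sketch with toy data:

import pandas as pd

points_df = pd.DataFrame(index=[10, 11, 12])    # index = synapse point IDs
partners_df = pd.DataFrame({'pre_id':  [10, 10, 99],
                            'post_id': [11, 98, 12]})

filtered = (partners_df
            .merge(points_df[[]], 'inner', left_on='pre_id', right_index=True)
            .merge(points_df[[]], 'inner', left_on='post_id', right_index=True))
# Only the (pre_id=10, post_id=11) row survives: both endpoints are in points_df.index.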