Example #1
def export_sparsevol(server, uuid, instance, neurons_df, scale=5, format='tiff', output_dir='.'):
    import os
    import vigra
    import numpy as np

    from neuclease.util import round_box, tqdm_proxy
    from neuclease.dvid import fetch_sparsevol, resolve_ref, fetch_volume_box, box_to_slicing

    uuid = resolve_ref(server, uuid)

    # Determine the segmentation bounding box at the given scale,
    # which is used as the mask shape.
    seg = (server, uuid, instance)
    box = round_box(fetch_volume_box(*seg), 64, 'out')
    box[0] = (0,0,0)
    box_scaled = box // 2**scale

    # How many digits will we need in each slice file name?
    digits = int(np.ceil(np.log10(box_scaled[1, 0])))

    # Export a mask stack for each group.
    groups = neurons_df.groupby('group', sort=False)
    num_groups = neurons_df['group'].nunique()
    group_prog = tqdm_proxy(groups, total=num_groups)
    for group, df in group_prog:
        group_prog.write(f'Group "{group}": Assembling mask')
        group_mask = np.zeros(box_scaled[1], dtype=bool)
        group_mask = vigra.taggedView(group_mask, 'zyx')

        # Overlay each body mask in the current group
        for body in tqdm_proxy(df['body'], leave=False):
            body_mask, mask_box = fetch_sparsevol(*seg, body, scale=scale, format='mask')
            group_mask[box_to_slicing(*mask_box)] |= body_mask

        # Write out the slice files
        group_prog.write(f'Group "{group}": Writing slices')
        d = f'{output_dir}/{group}.stack'
        os.makedirs(d, exist_ok=True)
        for z in tqdm_proxy(range(group_mask.shape[0]), leave=False):
            p = f'{d}/{z:0{digits}d}.{format}'
            vigra.impex.writeImage(group_mask[z].astype(np.uint8), p)
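
A minimal usage sketch for the function above, assuming a hypothetical neurons_df with the two columns the function reads ('group' and 'body'); the server, UUID, and body IDs below are placeholders:

import pandas as pd

# Hypothetical grouping: three bodies split across two mask stacks.
neurons_df = pd.DataFrame({
    'group': ['groupA', 'groupA', 'groupB'],
    'body': [1000, 2000, 3000],
})

export_sparsevol('http://emdata4:8900', 'abc123', 'segmentation',
                 neurons_df, scale=5, format='tiff', output_dir='/tmp/masks')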
Example #2
def _generate_and_store_mesh():
    try:
        dvid = request.args['dvid']
        body = request.args['body']
    except KeyError as ex:
        return Response(f"Missing required parameter: {ex.args[0]}", 400)

    segmentation = request.args.get('segmentation', 'segmentation')
    mesh_kv = request.args.get('mesh_kv', f'{segmentation}_meshes')

    uuid = request.args.get('uuid') or find_master(dvid)

    scale = request.args.get('scale')
    if scale is not None:
        scale = int(scale)

    smoothing = int(request.args.get('smoothing', 2))

    # Note: This is just the effective desired decimation assuming scale-1 data.
    # If we're forced to select a higher scale than scale-1, then we'll increase
    # this number to compensate.
    decimation = float(request.args.get('decimation', 0.1))

    user = request.args.get('u')
    user = user or request.args.get('user', "UNKNOWN")

    # TODO: The global cache of DVID sessions should store authentication info
    #       and use it as part of the key lookup, to avoid creating a new dvid
    #       session for every single cloud call!
    dvid_session = default_dvid_session('cloud-meshgen', user)
    auth = request.headers.get('Authorization')
    if auth:
        dvid_session = copy.deepcopy(dvid_session)
        dvid_session.headers['Authorization'] = auth

    with Timer(f"Body {body}: Fetching coarse sparsevol"):
        svc_ranges = fetch_sparsevol_coarse(dvid,
                                            uuid,
                                            segmentation,
                                            body,
                                            format='ranges',
                                            session=dvid_session)

    #svc_mask, _svc_box = fetch_sparsevol_coarse(dvid, uuid, segmentation, body, format='mask', session=dvid_session)
    #np.save(f'mask-{body}-svc.npy', svc_mask)

    # The coarse sparsevol is returned at scale 6,
    # so upscale the bounding box to scale-0 coordinates.
    box_s6 = rle_ranges_box(svc_ranges)
    box_s0 = box_s6 * (2**6)
    logger.info(f"Body {body}: Bounding box: {box_s0[:, ::-1].tolist()}")

    if scale is None:
        # Use scale 1 if possible or a higher scale
        # if necessary due to bounding-box RAM usage.
        scale = max(1, select_scale(box_s0))

    if scale > 1:
        # If we chose a low-res scale, then we
        # can reduce the decimation as needed.
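        # (For example, the default decimation of 0.1 becomes 0.4 at scale 2
        # and saturates at 1.0, i.e. no decimation at all, by scale 3.)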
        decimation = min(1.0, decimation * 4**(scale - 1))

    with Timer(f"Body {body}: Fetching scale-{scale} sparsevol"):
        mask, mask_box = fetch_sparsevol(dvid,
                                         uuid,
                                         segmentation,
                                         body,
                                         scale=scale,
                                         format='mask',
                                         session=dvid_session)
        # np.save(f'mask-{body}-s{scale}.npy', mask)

        # Pad with a thin halo of zeros to avoid holes in the mesh at the box boundary
        mask = np.pad(mask, 1)
        mask_box += [(-1, -1, -1), (1, 1, 1)]

    with Timer(f"Body {body}: Computing mesh"):
        # The 'ilastik' marching cubes implementation supports smoothing during mesh construction.
        mesh = Mesh.from_binary_vol(mask,
                                    mask_box * VOXEL_NM * (2**scale),
                                    smoothing_rounds=smoothing)

        logger.info(f"Body {body}: Decimating mesh at fraction {decimation}")
        mesh.simplify(decimation)

        logger.info(f"Body {body}: Preparing ngmesh")
        mesh_bytes = mesh.serialize(fmt='ngmesh')

    if scale > 2:
        logger.info(f"Body {body}: Not storing to dvid (scale > 2)")
    else:
        with Timer(
                f"Body {body}: Storing {body}.ngmesh in DVID ({len(mesh_bytes)/MB:.1f} MB)"
        ):
            try:
                post_key(dvid,
                         uuid,
                         mesh_kv,
                         f"{body}.ngmesh",
                         mesh_bytes,
                         session=dvid_session)
            except HTTPError as ex:
                err = ex.response.content.decode('utf-8')
                if 'locked node' in err:
                    logger.info(
                        f"Body {body}: Not storing to dvid (uuid {uuid[:4]} is locked)."
                    )
                else:
                    logger.warning(f"Mesh could not be cached to dvid:\n{err}")

    r = make_response(mesh_bytes)
    r.headers.set('Content-Type', 'application/octet-stream')
    return r
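
A hedged sketch of how a client might call this Flask handler; the host and route below are assumptions, since only the handler body appears above:

import requests

# Hypothetical deployment URL. The handler requires 'dvid' and 'body',
# and accepts optional uuid/scale/smoothing/decimation/user parameters.
resp = requests.get('http://meshgen.example.com/generate_mesh',
                    params={'dvid': 'http://emdata4:8900',
                            'body': 1000,
                            'smoothing': 2,
                            'decimation': 0.1,
                            'u': 'someuser'})
resp.raise_for_status()
with open('1000.ngmesh', 'wb') as f:
    f.write(resp.content)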
Example #3
        def compute_mesh_and_write(body):
            with Timer() as timer:
                # Fetch the sparsevol to determine the bounding-box size (in scale-0 voxels)
                try:
                    with mgr_client.access_context(server, True, 1, 0):
                        # sparsevol-coarse is at scale-6
                        coords_s6 = fetch_sparsevol_coarse(
                            server, uuid, instance, body, is_supervoxels)
                except Exception:
                    return (body, 0, 0, 0, 0.0, timer.seconds,
                            'error-sparsevol-coarse')

                box_s6 = np.array(
                    [coords_s6.min(axis=0), 1 + coords_s6.max(axis=0)])
                box_s0 = (2**6) * box_s6
                shape_s0 = (box_s0[1] - box_s0[0])
                box_voxels_s0 = np.prod(shape_s0.astype(float))

                # Determine the scale we'll use.
                # Solve for 'scale' in the following relationship:
                #
                #   box_voxels_s0/((2^scale)^3) <= max_box_voxels
                #
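                # e.g. box_voxels_s0 = 1e12 and max_box_voxels = 1e9 gives
                # scale = ceil(log2((1e12/1e9) ** (1/3))) = ceil(log2(10)) = 4.
                #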
                scale = log2(pow(box_voxels_s0 / max_box_voxels, 1 / 3))
                scale = max(ceil(scale), min_scale)

                if scale > max_scale:
                    raise RuntimeError(
                        f"Can't compute mesh for body {body}. Bounding box is {box_s0[:, ::-1].tolist()}, "
                        f"which is too large to fit in desired RAM, even at scale {max_scale}"
                    )

                try:
                    with mgr_client.access_context(server, True, 1, 0):
                        coords = fetch_sparsevol(server,
                                                 uuid,
                                                 instance,
                                                 body,
                                                 is_supervoxels,
                                                 scale,
                                                 dtype=np.int16)
                except Exception:
                    return (body, 0, 0, 0, 0.0, timer.seconds,
                            'error-sparsevol')

                box = box_s0 // (2**scale)
                coords -= box[0]
                num_voxels = len(coords)

                shape = box[1] - box[0]
                vol = np.zeros(shape, np.uint8)
                vol[(*coords.transpose(), )] = 1
                del coords

                try:
                    # Pass the scale-0 box so the mesh vertices come out in
                    # scale-0 coordinates, regardless of the scale we fetched.
                    mesh = Mesh.from_binary_vol(vol, box_s0)
                except Exception:
                    return (body, scale, num_voxels, 0, 0.0, timer.seconds,
                            'error-construction')

                del vol
                try:
                    mesh.laplacian_smooth(smoothing_iterations)
                except Exception:
                    return (body, scale, num_voxels, 0.0,
                            len(mesh.vertices_zyx), timer.seconds,
                            'error-smoothing')

                fraction = decimation_fraction
                if scale > min_scale:
                    # Since we're starting from a lower resolution than the user
                    # requested, reduce the decimation we're applying accordingly.
                    # Since meshes are 2D surfaces, we approximate the difference in
                    # vertices as the SQUARE of the difference in resolution.
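                    # (For example, min_scale=1 but scale=3 multiplies the
                    # fraction by (2**(3-1))**2 == 16.)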
                    fraction *= (2**(scale - min_scale))**2
                    fraction = min(fraction, 1.0)

                try:
                    mesh.simplify(fraction, in_memory=True)
                except Exception:
                    return (body, scale, num_voxels, 0.0,
                            len(mesh.vertices_zyx), timer.seconds,
                            'error-decimation')

                output_path = f'{options["output-directory"]}/{body}.{options["format"]}'
                mesh.serialize(output_path)

                return (body, scale, num_voxels, fraction,
                        len(mesh.vertices_zyx), timer.seconds, 'success')
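
A hedged sketch of collecting the worker's result tuples into a table; the driver loop and the bodies list are assumptions, since only the closure is shown above:

import pandas as pd

# Hypothetical driver: run the worker over a list of body IDs and tabulate
# the (body, scale, num_voxels, fraction, vertices, seconds, status) tuples
# it returns.
results = [compute_mesh_and_write(body) for body in bodies]
stats_df = pd.DataFrame(results,
                        columns=['body', 'scale', 'num_voxels', 'fraction',
                                 'vertices', 'seconds', 'status'])
print(stats_df.query('status != "success"'))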
Example #4
def autogen_points(input_seg,
                   count,
                   roi,
                   body,
                   tbars,
                   use_skeleton,
                   random_seed=None,
                   minimum_distance=0):
    """
    Generate a list of points within the input segmentation, based on the given criteria.
    See the main help text below for details.
    """
    if tbars and not body:
        sys.exit(
            "If you want to auto-generate tbar points, please specify a body.")

    if not tbars and not count:
        sys.exit(
            "You must supply a --count unless you are generating all tbars of a body."
        )

    if use_skeleton:
        if not body:
            sys.exit(
                "You must supply a body ID if you want to use a skeleton.")
        if tbars:
            sys.exit(
                "You can't select both tbar points and skeleton points.  Pick one or the other."
            )
        if not count and minimum_distance > 0:
            sys.exit(
                "You must supply a --count if you want skeleton point samples to respect the minimum distance."
            )
        if not count and not roi and minimum_distance == 0:
            logger.warning(
                "You are using all nodes of a skeleton without any ROI filter! Is that what you meant?"
            )

    rng = default_rng(random_seed)

    if tbars:
        logger.info(f"Fetching synapses for body {body}")
        syn_df = fetch_annotation_label(*input_seg[:2],
                                        'synapses',
                                        body,
                                        format='pandas')
        tbars = syn_df.query('kind == "PreSyn"')[[*'zyx']]

        if roi:
            logger.info(f"Filtering tbars for roi {roi}")
            determine_point_rois(*input_seg[:2], [roi], tbars)
            tbars = tbars.query('roi == @roi')[[*'zyx']]

        if minimum_distance:
            logger.info(
                f"Pruning close points from {len(tbars)} total tbar points")
            tbars = prune_close_pairs(tbars, minimum_distance, rng)
            logger.info(f"After pruning, {len(tbars)} tbars remain.")

        if count:
            count = min(count, len(tbars))
            logger.info(f"Sampling {count} tbars")
            choices = rng.choice(tbars.index, size=count, replace=False)
            tbars = tbars.loc[choices]

        logger.info(f"Returning {len(tbars)} tbar points")
        return tbars

    elif use_skeleton:
        assert body
        logger.info(f"Fetching skeleton for body {body}")
        skeleton_instance = f'{input_seg[2]}_skeletons'
        swc = fetch_key(*input_seg[:2], skeleton_instance, f'{body}_swc')
        skeleton_df = swc_to_dataframe(swc)
        skeleton_df['x'] = skeleton_df['x'].astype(int)
        skeleton_df['y'] = skeleton_df['y'].astype(int)
        skeleton_df['z'] = skeleton_df['z'].astype(int)

        if roi:
            logger.info(f"Filtering skeleton for roi {roi}")
            determine_point_rois(*input_seg[:2], [roi], skeleton_df)
            skeleton_df = skeleton_df.query('roi == @roi')[[*'zyx']]

        if minimum_distance:
            assert count
            # Distance-pruning is very expensive on a huge number of close points.
            # If skeleton is large, first reduce the workload by pre-selecting a
            # random sample of skeleton points, and prune more from there.
            if len(skeleton_df) > 10_000:
                # FIXME: random_state can't use rng until I upgrade to pandas 1.0
                skeleton_df = skeleton_df.sample(min(4 * count, len(skeleton_df)),
                                                 random_state=None)
            logger.info(
                f"Pruning close points from {len(skeleton_df)} skeleton points"
            )
            skeleton_df = prune_close_pairs(skeleton_df, minimum_distance, rng)
            logger.info(
                f"After pruning, {len(skeleton_df)} skeleton points remain.")

        if count:
            count = min(count, len(skeleton_df))
            logger.info(f"Sampling {count} skeleton points")
            choices = rng.choice(skeleton_df.index, size=count, replace=False)
            skeleton_df = skeleton_df.loc[choices]

        logger.info(f"Returning {len(skeleton_df)} skeleton points")
        return skeleton_df

    elif body:
        assert count
        if roi:
            # TODO: intersect the ranges with the ROI.
            raise NotImplementedError(
                "Sorry, I haven't yet implemented support for "
                "body+roi filtering.  Pick one or the other, "
                "or ask Stuart to fix this.")

        logger.info(f"Fetching sparsevol for body {body}")
        ranges = fetch_sparsevol(*input_seg, body, format='ranges')
        logger.info("Sampling from sparsevol")

        if minimum_distance > 0:
            # Sample 4x extra so we still have enough after pruning.
            points = sample_points_from_ranges(ranges, 4 * count, rng)
        else:
            points = sample_points_from_ranges(ranges, count, rng)

        points = pd.DataFrame(points, columns=[*'zyx'])

        if minimum_distance > 0:
            logger.info(f"Pruning close points from {len(points)} body points")
            points = prune_close_pairs(points, minimum_distance, rng)
            logger.info(f"After pruning, {len(points)} body points remain")

        points = points.iloc[:count]
        logger.info(f"Returning {len(points)} body points")
        return points

    elif roi:
        assert count
        logger.info(f"Fetching roi {roi}")
        roi_ranges = fetch_roi_roi(*input_seg[:2], roi, format='ranges')
        logger.info("Sampling from ranges")

        if minimum_distance > 0:
            # Sample 4x extra so we can prune some out if necessary.
            points_s5 = sample_points_from_ranges(roi_ranges, 4 * count, rng)
        else:
            points_s5 = sample_points_from_ranges(roi_ranges, count, rng)

        # ROI ranges are encoded at scale 5, so pick a random
        # scale-0 point within each sampled scale-5 voxel.
        corners_s0 = points_s5 * (2**5)
        points_s0 = rng.integers(corners_s0, corners_s0 + (2**5))
        points = pd.DataFrame(points_s0, columns=[*'zyx'])

        if minimum_distance > 0:
            logger.info(f"Pruning close points from {len(points)} roi points")
            points = prune_close_pairs(points, minimum_distance, rng)
            logger.info(f"After pruning, {len(points)} roi points remain")

        points = points.iloc[:count]
        logger.info(f"Returning {len(points)} roi points")
        return points
    else:
        # No body or roi specified, just sample from the whole non-zero segmentation area
        assert count
        logger.info("Sampling random points from entire input segmentation")
        logger.info("Fetching low-res input volume")
        box_s6 = round_box(fetch_volume_box(*input_seg), 2**6, 'out') // 2**6
        seg_s6 = fetch_labelmap_voxels(*input_seg, box_s6, scale=6)
        mask_s6 = seg_s6.astype(bool)
        logger.info("Encoding segmentation as ranges")
        seg_ranges = runlength_encode_mask_to_ranges(mask_s6, box_s6)

        logger.info("Sampling from ranges")

        if minimum_distance > 0:
            # Sample 4x extra so we can prune some out if necessary.
            points_s6 = sample_points_from_ranges(seg_ranges, 4 * count, rng)
        else:
            points_s6 = sample_points_from_ranges(seg_ranges, count, rng)

        corners_s0 = points_s6 * (2**6)
        points_s0 = rng.integers(corners_s0, corners_s0 + (2**6))

        points = pd.DataFrame(points_s0, columns=[*'zyx'])

        if minimum_distance > 0:
            logger.info(
                f"Pruning close points from {len(points)} segmentation points")
            points = prune_close_pairs(points, minimum_distance, rng)
            logger.info(
                f"After pruning, {len(points)} segmentation points remain")

        points = points.iloc[:count]
        logger.info(f"Returning {len(points)} segmentation points")
        return points
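
A hedged usage sketch for autogen_points; judging from the calls above, input_seg is a (server, uuid, instance) tuple, and the server, UUID, and body ID below are placeholders:

# Sample 100 points from a single body, keeping them at least 500 voxels apart.
input_seg = ('http://emdata4:8900', 'abc123', 'segmentation')
points = autogen_points(input_seg,
                        count=100,
                        roi=None,
                        body=1000,
                        tbars=False,
                        use_skeleton=False,
                        random_seed=0,
                        minimum_distance=500)
print(points)  # DataFrame with columns z, y, x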