Example 1
def fill_all_holes(cc_labels, progress=False, return_fill_count=False):
    """
  Fills the holes in each connected component and removes components that
  get filled in. The idea is that holes (entirely contained labels or background) 
  are artifacts in cell segmentations. A common example is a nucleus segmented 
  separately from the rest of the cell or errors in a manual segmentation leaving
  a void in a dendrite.

  cc_labels: an image containing connected components with labels smaller than
    the number of voxels in the image.
  progress: Display a progress bar or not.
  return_fill_count: if specified, return a tuple (filled_image, N) where N is
    the number of voxels that were filled in.

  Returns: filled_in_labels
  """
    labels = fastremap.unique(cc_labels)
    labels_set = set(labels)
    labels_set.discard(0)

    all_slices = find_objects(cc_labels)
    pixels_filled = 0

    for label in tqdm(labels, disable=(not progress), desc="Filling Holes"):
        if label not in labels_set:
            continue

        slices = all_slices[label - 1]
        if slices is None:
            continue

        binary_image = (cc_labels[slices] == label)
        binary_image, N = fill_voids.fill(binary_image,
                                          in_place=True,
                                          return_fill_count=True)
        pixels_filled += N
        if N == 0:
            continue

        sub_labels = set(fastremap.unique(cc_labels[slices] * binary_image))
        sub_labels.remove(label)
        labels_set -= sub_labels
        cc_labels[slices] = cc_labels[slices] * ~binary_image + label * binary_image

    if return_fill_count:
        return cc_labels, pixels_filled
    return cc_labels
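The core of the loop above pairs fill_voids.fill with a relabeling step: fill the binary mask of one label, then overwrite whatever labels sat inside the filled region. A minimal, self-contained sketch of that step on a toy volume (the array and label values here are invented for illustration):

import numpy as np
import fill_voids
import fastremap

# Toy volume: label 1 with a single interior voxel labeled 2 (a "hole").
labels = np.ones((5, 5, 5), dtype=np.uint32)
labels[2, 2, 2] = 2

binary = (labels == 1)
filled, n_filled = fill_voids.fill(binary, return_fill_count=True)
print(n_filled)                          # 1 voxel was filled in

# Overwrite the hole with the surrounding label, as in the loop above.
labels = labels * ~filled + 1 * filled
print(fastremap.unique(labels))          # [1] -- label 2 was absorbed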
Example 2
 def mask_fragments(self, voxel_num_threshold: int):
     uniq, counts = fastremap.unique(self.array, return_counts=True)
     fragment_ids = uniq[counts <= voxel_num_threshold]
     logging.info(
         f'masking out {len(fragment_ids)} fragments in {len(uniq)} with a percentage of {len(fragment_ids)/len(uniq)}'
     )
     self.array = fastremap.mask(self.array, fragment_ids)
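A sketch of the same return_counts / mask pattern on a made-up array (values invented; fastremap.mask zeroes out the listed labels):

import numpy as np
import fastremap

arr = np.array([[1, 1, 1],
                [1, 7, 1],
                [1, 1, 1]], dtype=np.uint32)
uniq, counts = fastremap.unique(arr, return_counts=True)
fragment_ids = uniq[counts <= 1]          # "fragments" at or below a 1-voxel threshold
arr = fastremap.mask(arr, fragment_ids)   # masked labels become 0
print(fastremap.unique(arr))              # [0 1]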
Example 3
def labels(filedata: bytes,
           encoding: str,
           shape=None,
           dtype=None,
           block_size=None,
           background_color: int = 0) -> np.ndarray:
    """
  Extract unique labels from a chunk using
  the most efficient means possible for the
  encoding type.

  Returns: numpy array of unique values
  """
    if filedata is None or len(filedata) == 0:
        return np.zeros((0, ), dtype=dtype)
    elif encoding == "raw":
        img = decode(filedata, encoding, shape, dtype, block_size,
                     background_color)
        return fastremap.unique(img)
    elif encoding == "compressed_segmentation":
        return cseg.labels(filedata,
                           shape=shape[:3],
                           dtype=dtype,
                           block_size=block_size)
    elif encoding == "compresso":
        return compresso.labels(filedata)
    else:
        raise NotImplementedError(
            f"Encoding {encoding} is not supported. Try: raw, compressed_segmentation, or compresso."
        )
Example 4
  def _remove_dust(self, data, dust_threshold):
    if dust_threshold:
      segids, pxct = fastremap.unique(data, return_counts=True)
      dust_segids = [ sid for sid, ct in zip(segids, pxct) if ct < int(dust_threshold) ]
      data = fastremap.mask(data, dust_segids, in_place=True)

    return data
Example 5
 def process_shell(labels, bbox):
   nonlocal all_labels
   nonlocal requested_bbox
   crop_bbox = Bbox.intersection(requested_bbox, bbox)
   crop_bbox -= bbox.minpt
   labels = labels[ crop_bbox.to_slices() ]
   all_labels |= set(fastremap.unique(labels))
Example 6
    def components(self):
        """
    Extract connected components from graph. 
    Useful for ensuring that you're working with a single tree.

    Returns: [ Skeleton, Skeleton, ... ]
    """
        skel = self.clone()
        forest = self._compute_components(skel)

        if len(forest) == 0:
            return []
        elif len(forest) == 1:
            return [skel]

        skeletons = []
        for edge_list in forest:
            edge_list = np.array(edge_list, dtype=np.uint32)
            vert_idx = fastremap.unique(edge_list)

            vert_list = skel.vertices[vert_idx]
            radii = skel.radii[vert_idx]
            vtypes = skel.vertex_types[vert_idx]

            remap = {vid: i for i, vid in enumerate(vert_idx)}
            edge_list = fastremap.remap(edge_list, remap, in_place=True)

            skeletons.append(
                Skeleton(vert_list, edge_list, radii, vtypes, skel.id))

        return skeletons
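The renumbering trick in the loop above (collect vertex ids with fastremap.unique, build a dict, then fastremap.remap) can be tried in isolation; the edge values below are invented for illustration:

import numpy as np
import fastremap

# Edge list referencing non-contiguous vertex ids.
edge_list = np.array([[10, 42], [42, 99]], dtype=np.uint32)
vert_idx = fastremap.unique(edge_list)               # [10 42 99]
remap = {vid: i for i, vid in enumerate(vert_idx)}
edge_list = fastremap.remap(edge_list, remap, in_place=True)
print(edge_list)                                     # [[0 1] [1 2]]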
Example 7
def engage_avocado_protection(cc_labels, all_dbf, remapping,
                              soma_detection_threshold, edtfn, progress):
    orig_cc_labels = np.copy(cc_labels, order='F')

    unchanged = set()
    max_iterations = max(fastremap.unique(cc_labels))

    # This loop handles nested avocados
    # Unless there are deeply nested double avocados,
    # this should complete in 2-3 passes. We limit it
    # to 20 just to make sure this loop terminates no matter what.
    # Avocados aren't the end of the world.
    for _ in tqdm(range(20), disable=(not progress), desc="Avocado Pass"):
        # Note: Divide soma_detection_threshold by a bit more than 2 because the nuclei are going to be
        # about a factor of 2 or less smaller than what we'd expect from a cell. For example,
        # in an avocado I saw, the DBF of the nucleus was 499 when the detection threshold was
        # set to 1100.
        candidates = set(
            fastremap.unique(cc_labels *
                             (all_dbf > soma_detection_threshold / 2.5)))
        candidates -= unchanged
        candidates.discard(0)

        cc_labels, unchanged_this_cycle, changes = engage_avocado_protection_single_pass(
            cc_labels,
            all_dbf,
            candidates=candidates,
            progress=progress,
        )
        unchanged |= unchanged_this_cycle

        if len(changes) == 0:
            break

        all_dbf = edtfn(cc_labels)

    # Downstream logic assumes cc_labels is contiguously numbered
    cc_labels, _ = fastremap.renumber(cc_labels, in_place=True)
    cc_remapping = kimimaro.skeletontricks.get_mapping(orig_cc_labels,
                                                       cc_labels)

    adjusted_remapping = {}
    for new_cc, cc in cc_remapping.items():
        if cc in remapping:
            adjusted_remapping[new_cc] = remapping[cc]

    return cc_labels, all_dbf, adjusted_remapping
Example 8
def engage_avocado_protection_single_pass(cc_labels,
                                          all_dbf,
                                          candidates=None,
                                          progress=False):
    """
  For each candidate, check if there's a fruit around the
  avocado pit roughly from the center (the max EDT).
  """

    if candidates is None:
        candidates = fastremap.unique(cc_labels)

    candidates = [label for label in candidates if label != 0]

    unchanged = set()
    changed = set()

    if len(candidates) == 0:
        return cc_labels, unchanged, changed

    def paint_walls(binimg):
        """
    Ensure that inclusions that touch the wall are handled
    by performing a 2D fill on each wall.
    """
        binimg[:, :, 0] = fill_voids.fill(binimg[:, :, 0])
        binimg[:, :, -1] = fill_voids.fill(binimg[:, :, -1])
        binimg[:, 0, :] = fill_voids.fill(binimg[:, 0, :])
        binimg[:, -1, :] = fill_voids.fill(binimg[:, -1, :])
        binimg[0, :, :] = fill_voids.fill(binimg[0, :, :])
        binimg[-1, :, :] = fill_voids.fill(binimg[-1, :, :])
        return binimg

    remap = {}
    for label in tqdm(candidates,
                      disable=(not progress),
                      desc="Fixing Avocados"):
        binimg = paint_walls(cc_labels == label)  # image of the pit
        coord = argmax(binimg * all_dbf)

        (pit, fruit) = kimimaro.skeletontricks.find_avocado_fruit(
            cc_labels, coord[0], coord[1], coord[2])
        if pit == fruit and pit not in changed:
            unchanged.add(pit)
        else:
            unchanged.discard(pit)
            unchanged.discard(fruit)
            changed.add(pit)
            changed.add(fruit)
            binimg |= (cc_labels == fruit)

        binimg, N = fill_voids.fill(binimg,
                                    in_place=True,
                                    return_fill_count=True)
        cc_labels *= ~binimg
        cc_labels += label * binimg

    return cc_labels, unchanged, changed
Example 9
    def _single_tree_interjoint_paths(self, skeleton, return_indices):
        vertices = skeleton.vertices
        edges = skeleton.edges

        unique_nodes, unique_counts = fastremap.unique(edges,
                                                       return_counts=True)
        terminal_nodes = unique_nodes[unique_counts == 1]
        branch_nodes = set(unique_nodes[unique_counts >= 3])

        critical_points = set(terminal_nodes)
        critical_points.update(branch_nodes)

        tree = defaultdict(set)

        for e1, e2 in edges:
            tree[e1].add(e2)
            tree[e2].add(e1)

        # The below depth first search would be
        # more elegantly implemented as recursion,
        # but it quickly blows the stack, mandating
        # an iterative implementation.

        paths = []

        stack = [terminal_nodes[0]]
        criticals = [terminal_nodes[0]]
        # Saving the path stack is memory intensive
        # There might be a way to do it more linearly
        # via a DFS rather than BFS strategy.
        path_stack = [[]]

        visited = defaultdict(bool)

        while stack:
            node = stack.pop()
            root = criticals.pop()  # "root" is used v. loosely here
            path = path_stack.pop()

            path.append(node)
            visited[node] = True

            if node != root and node in critical_points:
                paths.append(path)
                path = [node]
                root = node

            for child in tree[node]:
                if not visited[child]:
                    stack.append(child)
                    criticals.append(root)
                    path_stack.append(list(path))

        if return_indices:
            return paths

        return [vertices[path] for path in paths]
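The terminal/branch classification at the top of this method depends only on how often a node id appears in the edge list; a small sketch on a hand-made edge list:

import numpy as np
import fastremap

# Path 0-1-2 with a side branch 1-3: node 1 is a branch point; 0, 2 and 3 are tips.
edges = np.array([[0, 1], [1, 2], [1, 3]], dtype=np.uint32)
nodes, counts = fastremap.unique(edges, return_counts=True)
print(nodes[counts == 1])   # terminal nodes: [0 2 3]
print(nodes[counts >= 3])   # branch nodes:   [1]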
Example 10
  def agglomerate_cutout(self, img, timestamp=None, stop_layer=None):
    """Remap a graphene volume to its latest root ids. This creates a flat segmentation."""
    if np.all(img == self.image.background_color) or stop_layer == 1:
      return img

    labels = fastremap.unique(img)
    if labels.size and labels[0] == 0:
      labels = labels[1:]

    roots = self.get_roots(labels, timestamp=timestamp, binary=True, stop_layer=stop_layer)
    mapping = { segid: root for segid, root in zip(labels, roots) }
    return fastremap.remap(img, mapping, preserve_missing_labels=True, in_place=True)
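A toy version of the remapping step, with a made-up mapping standing in for the get_roots() response; preserve_missing_labels=True leaves labels absent from the mapping (here the background 0) untouched:

import numpy as np
import fastremap

img = np.array([0, 5, 5, 9], dtype=np.uint64)
labels = fastremap.unique(img)
labels = labels[labels != 0]                     # drop background, as above
mapping = {5: 100, 9: 200}                       # stand-in for segid -> root id
img = fastremap.remap(img, mapping, preserve_missing_labels=True, in_place=True)
print(img)                                       # [  0 100 100 200]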
Example 11
    def process_renumber(img3d, bbox):
      nonlocal N
      nonlocal lock 
      nonlocal remap
      nonlocal renderbuffer
      img_labels = fastremap.unique(img3d)
      with lock:
        for lbl in img_labels:
          if lbl not in remap:
            remap[lbl] = N
            N += 1
        if N > np.iinfo(renderbuffer.dtype).max:
          renderbuffer = fastremap.refit(renderbuffer, value=N, increase_only=True)

        fastremap.remap(img3d, remap, in_place=True)
        shade(renderbuffer, requested_bbox, img3d, bbox)
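fastremap.refit picks the smallest dtype that can hold a given value, which is the widening step above relies on; a toy sketch (values invented, and the exact dtype returned may vary by fastremap version):

import numpy as np
import fastremap

renderbuffer = np.zeros((4, 4), dtype=np.uint8)
N = 70000                                    # exceeds the uint8 range
if N > np.iinfo(renderbuffer.dtype).max:
    renderbuffer = fastremap.refit(renderbuffer, value=N)
print(renderbuffer.dtype)                    # a wider dtype, e.g. uint32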
Example 12
    def _compute_components(self, skel):
        if skel.edges.size == 0:
            return []

        index = defaultdict(set)
        visited = defaultdict(bool)
        for e1, e2 in skel.edges:
            index[e1].add(e2)
            index[e2].add(e1)

        def extract_component(start):
            edge_list = []
            stack = [start]
            parents = [-1]

            while stack:
                node = stack.pop()
                parent = parents.pop()

                if node < parent:
                    edge_list.append((node, parent))
                else:
                    edge_list.append((parent, node))

                if visited[node]:
                    continue

                visited[node] = True

                for child in index[node]:
                    stack.append(child)
                    parents.append(node)

            return np.unique(edge_list[1:], axis=0)

        forest = []
        for edge in fastremap.unique(skel.edges.flatten()):
            if visited[edge]:
                continue

            forest.append(extract_component(edge))

        return forest
Example 13
    def get_roots(self, segids, timestamp=None, binary=True, stop_layer=None):
        """
    Get the root ids for these labels.

    segids: (int or iterable) one or more segids to remap
    timestamp: get the roots from this date and time
      formats accepted:
        int: unix timestamp
        datetime: self-explanatory
        string: ISO 8601 date
    binary: if true, send and receive segids as a binary stream;
      otherwise use JSON. The difference can be about 2x in
      bandwidth used.
    stop_layer: (int) if specified, return the lowest parent at or above 
      that layer. If not specified, go all the way to the root id. 
        Layer 1: Watershed
        Layer 2: Within-Chunk Agglomeration
        Layer 2+: Between chunk interconnections (skip connections possible)
    """
        segids = toiter(segids)
        input_segids = np.array(segids, dtype=self.meta.dtype)

        if input_segids.size == 0:
            return np.array([], dtype=self.meta.dtype)

        segids = fastremap.unique(input_segids)

        base_remap = {0: 0}
        # skip ids that are already root IDs
        for segid in segids:
            layer_id = self.meta.decode_layer_id(segid)
            if layer_id in (stop_layer, self.meta.n_layers):
                base_remap[segid] = segid

        segids = np.array(
            [segid for segid in segids if segid not in base_remap],
            dtype=self.meta.dtype)

        timestamp = to_unix_time(timestamp)

        if stop_layer is not None:
            stop_layer = int(stop_layer)
            if stop_layer < 1 or stop_layer > self.meta.n_layers:
                raise ValueError(
                    "stop_layer ({}) must be between 1 and {} inclusive.".
                    format(stop_layer, self.meta.n_layers))

        if self.meta.supports_api('v1'):
            roots = self._get_roots_v1(segids, timestamp, binary, stop_layer)
        elif self.meta.supports_api('1.0'):
            roots = self._get_roots_legacy(segids, timestamp)
        else:
            raise exceptions.UnsupportedGrapheneAPIVersionError(
              "{} is not a supported API version. Supported versions: ".format(self.meta.api_version) \
              + ", ".join([ str(_) for _ in self.meta.supported_api_versions ])
            )

        for segid, root_id in zip(segids, roots):
            base_remap[segid] = root_id

        return fastremap.remap(input_segids, base_remap)
Example 14
def _mesh_from_voxels_chunked(voxels,
                              spacing=(1, 1, 1),
                              step_size=1,
                              chunk_size=200,
                              pad_chunks=True,
                              merge_fragments=True,
                              progress=True):
    """Generate mesh from voxels in chunks using marching cubes.

    Potentially faster and much more memory efficient but might introduce
    internal faces and/or holes.

    Parameters
    ----------
    voxels :        np.array
                    Voxel coordinates. Array of (N, 3) XYZ indices.
    spacing :       np.array
                    (3, ) array with x, y, z voxel size.
    step_size :     int, optional
                    Step size for marching cube algorithm.
                    Higher values = faster but coarser.
    chunk_size :    int
                    Size of the cubes in voxels in which to process the data.
                    The bigger the chunks, the more memory is used but there is
                    less chance of errors in the mesh.
    pad_chunks :    bool
                    If True, will pad each chunk. This helps make meshes
                    watertight but may introduce internal faces when merging
                    mesh fragments.
    merge_fragments :  bool
                    If True, will attempt to merge fragments at the chunk
                    boundaries.

    Returns
    -------
    trimesh.Trimesh

    """
    # Use marching cubes to create surface model
    # (newer versions of skimage have a "marching_cubes" function and
    # marching_cubes_lewiner is deprecated)
    marching_cubes = getattr(measure, 'marching_cubes',
                             getattr(measure, 'marching_cubes_lewiner', None))

    # Strip the voxels
    offset = voxels.min(axis=0)
    voxels = voxels - offset

    # Map voxels to chunks
    chunks = (voxels / chunk_size).astype(int)

    # Find the largest index
    # max_ix = chunks.max()
    # base = math.ceil(np.sqrt(max_ix))
    base = 16  # 2**16=65,536 max index - this should be sufficient for chunks

    # Now we encode the indices (x, y, z) chunk indices as packed integer:
    # Each (xyz) chunk is encoded as single integer which speeds things up a lot
    # For example with base = 16, chunk (1, 2, 3) becomes:
    # N = 2 ** 16 = 65,536
    # (1 * N ** 2) + (2 * N) + 3 = 4,295,098,371
    # This obviously only works as long as we can still squeeze the chunks into
    # 64bit integers but that should work out even at scale 0. If this ever
    # becomes an issue, we could start using strings instead. For now, base 16
    # should be enough.
    chunks_packed = pack_array(chunks, base=base)

    # Find unique chunks
    chunks_unique = unique(chunks_packed)

    # For each chunk also find voxels that are directly adjacent (in plus direction)
    # This makes it so that each voxel can belong to multiple chunks
    # What we want to get is an (4, N) array where for each voxel we have its
    # "original" chunk index and its chunks when offset by -1 along each axis.
    # original chunk -> [[1, 1, 1, 0, 0, 0, 0],
    # offset x by -1 ->  [1, 1, 1, 1, 0, 1, 0],
    # offset y by -1 ->  [1, 1, 1, 0, 1, 0, 0]
    # offset z by -1 ->  [1, 1, 1, 0, 0, 0, 1]]
    # The simple numbers (0, 1) in this example will obvs be our packed integers
    # Later on we can then ask: "find me all voxels in chunk 0, 1, etc."
    voxel2chunk = np.full((4, len(voxels)), fill_value=-1, dtype=int)
    voxel2chunk[-1, :] = chunks_packed

    # Find offset chunks along each axis
    for k in range(3):
        # Offset chunks and pack
        chunks_offset = voxels.copy()
        chunks_offset[:, k] -= 1
        chunks_offset = (chunks_offset / chunk_size).astype(int)
        chunks_offset = pack_array(chunks_offset, base=base)

        voxel2chunk[k] = chunks_offset

    # Unpack the unique chunks
    chunks_unique_unpacked = unpack_array(chunks_unique, base=base)

    # Generate the fragments
    fragments = []
    end_chunks = chunks_unique_unpacked.max(axis=0)
    pad = np.array([[1, 1], [1, 1], [1, 1]])
    for i, (ch, ix) in config.tqdm(enumerate(zip(chunks_unique, chunks_unique_unpacked)),
                                   total=len(chunks_unique),
                                   disable=not progress,
                                   leave=False,
                                   desc='Meshing'):
        # Pad the matrices only for the first and last chunks along each axis
        if not pad_chunks:
            pad = np.array([[0, 0], [0, 0], [0, 0]])
            for k in range(3):
                if ix[k] == 0:
                    pad[k][0] = 1
                if ix[k] == end_chunks[k]:
                    pad[k][1] = 1

        # Get voxels in this chunk
        this_vx = voxels[np.any(voxel2chunk == ch, axis=0)]

        # If only a single voxel, skip.
        if this_vx.shape[0] <= 1:
            continue

        # Turn voxels into matrix with given padding
        mat, chunk_offset = _voxels_to_matrix(this_vx, pad=pad.tolist())

        # Marching cubes needs at least a (2, 2, 2) matrix
        # We could in theory make this work by adding padding
        # but we probably don't lose much if we skip them.
        # There is a chance that we might introduce holes in the mesh
        # though
        if any([s < 2 for s in mat.shape]):
            continue

        # Run the actual marching cubes
        v, f, _, _ = marching_cubes(mat,
                                    level=.5,
                                    step_size=step_size,
                                    allow_degenerate=False,
                                    gradient_direction='ascent',
                                    spacing=(1, 1, 1))

        # Remove the padding (if any)
        v -= pad[:, 0]
        # Add chunk offset
        v += chunk_offset

        fragments.append((v, f))

    # Combine into a single mesh
    all_verts = []
    all_faces = []
    verts_offset = 0
    for frag in fragments:
        all_verts.append(frag[0])
        all_faces.append(frag[1] + verts_offset)
        verts_offset += len(frag[0])

    all_verts = (np.concatenate(all_verts, axis=0) + offset) * spacing
    all_faces = np.concatenate(all_faces, axis=0)

    # Make trimesh
    m = tm.Trimesh(all_verts, all_faces)

    # Deduplicate chunk boundaries
    # This is not necessarily the cleanest solution but it should do the
    # trick most of the time
    if merge_fragments:
        m.merge_vertices(digits_vertex=1)

    return m
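pack_array / unpack_array above are helpers from the surrounding module; a standalone sketch of the same base-2**16 packing idea, reproducing the worked example from the comments (the helper names below are hypothetical):

import numpy as np

def pack_xyz(chunks, base=16):
    # Encode (x, y, z) chunk indices as x*N**2 + y*N + z with N = 2**base.
    N = np.uint64(2 ** base)
    c = chunks.astype(np.uint64)
    return c[:, 0] * N * N + c[:, 1] * N + c[:, 2]

def unpack_xyz(packed, base=16):
    N = np.uint64(2 ** base)
    x, rem = np.divmod(packed, N * N)
    y, z = np.divmod(rem, N)
    return np.stack([x, y, z], axis=1)

chunks = np.array([[1, 2, 3]])
packed = pack_xyz(chunks)
print(packed)                 # [4295098371]
print(unpack_xyz(packed))     # [[1 2 3]]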
Example 15
def by_teasar(mesh, inv_dist, min_length=None, root=None, progress=True):
    """Skeletonize a mesh mesh using the TEASAR algorithm [1].

    This algorithm finds the longest path from a root vertex, invalidates all
    vertices that are within `inv_dist`. Then picks the second longest (and
    still valid) path and does the same. Rinse & repeat until all vertices have
    been invalidated. It's fast + works very well with tubular meshes, and with
    `inv_dist` you have control over the level of detail. Note that by its
    nature the skeleton will be exactly on the surface of the mesh.

    Based on the implementation by Sven Dorkenwald, Casey Schneider-Mizell and
    Forrest Collman in `meshparty` (https://github.com/sdorkenw/MeshParty).

    Parameters
    ----------
    mesh :          mesh obj
                    The mesh to be skeletonized. Can be an object that has
                    ``.vertices`` and ``.faces`` properties  (e.g. a
                    trimesh.Trimesh) or a tuple ``(vertices, faces)`` or a
                    dictionary ``{'vertices': vertices, 'faces': faces}``.
    inv_dist :      int | float
                    Distance along the mesh used for invalidation of vertices.
                    This controls how detailed (or noisy) the skeleton will be.
    min_length :    float, optional
                    If provided, will skip any branch that is shorter than
                    `min_length`. Use this to get rid of noise but note that
                    it will lead to vertices not being mapped to skeleton nodes.
                    Such vertices will show up with index -1 in
                    `Skeleton.mesh_map`.
    root :          int, optional
                    Vertex ID of a root. If not provided will use ``0``.
    progress :      bool, optional
                    If True, will show progress bar.

    Returns
    -------
    skeletor.Skeleton
                    Holds results of the skeletonization and enables quick
                    visualization.

    References
    ----------
    [1] Sato, M., Bitter, I., Bender, M. A., Kaufman, A. E., & Nakajima, M.
        (n.d.). TEASAR: tree-structure extraction algorithm for accurate and
        robust skeletons. In Proceedings the Eighth Pacific Conference on
        Computer Graphics and Applications. IEEE Comput. Soc.
        https://doi.org/10.1109/pccga.2000.883951

    """
    mesh = make_trimesh(mesh, validate=False)

    # Generate Graph (must be undirected)
    G = ig.Graph(edges=mesh.edges_unique, directed=False)
    G.es['weight'] = mesh.edges_unique_length

    if not root:
        root = 0

    edges = np.array([], dtype=np.int64)
    mesh_map = np.full(mesh.vertices.shape[0], fill_value=-1)

    with tqdm(desc='Invalidating',
              total=len(G.vs),
              disable=not progress,
              leave=False) as pbar:
        for cc in sorted(G.clusters(), key=len, reverse=True):
            # Make a subgraph for this connected component
            SG = G.subgraph(cc)
            cc = np.array(cc)

            # Find root within subgraph
            if root in cc:
                this_root = np.where(cc == root)[0][0]
            else:
                this_root = 0

            # Get the sparse adjacency matrix of the subgraph
            sp = SG.get_adjacency_sparse('weight')

            # Get lengths of paths to all nodes from root
            paths = SG.shortest_paths(this_root,
                                      target=None,
                                      weights='weight',
                                      mode='ALL')[0]
            paths = np.array(paths)

            # Prep array for invalidation
            valid = ~np.zeros(paths.shape).astype(bool)
            invalidated = 0

            while np.any(valid):
                # Find the farthest point
                farthest = np.argmax(paths)

                # Get path from root to farthest point
                path = SG.get_shortest_paths(this_root,
                                             farthest,
                                             weights='weight',
                                             mode='ALL')[0]

                # Get IDs of edges along the path
                eids = SG.get_eids(path=path, directed=False)

                # Stop if farthest point is closer than min_length
                add = True
                if min_length:
                    # This should only be distance to the first branchpoint
                    # from the tip since we set other weights to zero
                    le = sum(SG.es[eids].get_attribute_values('weight'))
                    if le < min_length:
                        add = False

                if add:
                    # Add these new edges
                    new_edges = np.vstack((cc[path[:-1]], cc[path[1:]])).T
                    edges = np.append(edges, new_edges).reshape(-1, 2)

                # Invalidate points in the path
                valid[path] = False
                paths[path] = 0

                # Must set weights along path to 0 so that this path is
                # taken again in future iterations
                SG.es[eids]['weight'] = 0

                # Get all nodes within `inv_dist` to this path
                # Note: can we somehow only include still valid nodes to speed
                # things up?
                dist, _, sources = dijkstra(sp,
                                            directed=False,
                                            indices=path,
                                            limit=inv_dist,
                                            min_only=True,
                                            return_predecessors=True)

                # Invalidate
                in_dist = dist <= inv_dist
                to_invalidate = np.where(in_dist)[0]
                valid[to_invalidate] = False
                paths[to_invalidate] = 0

                # Update mesh vertex to skeleton node map
                mesh_map[cc[in_dist]] = cc[sources[in_dist]]

                pbar.update((~valid).sum() - invalidated)
                invalidated = (~valid).sum()

    # Make unique edges (paths will have overlapped!)
    edges = unique(edges, axis=0)

    # Create a directed acyclic and hierarchical graph
    G_nx = edges_to_graph(edges=edges[:, [1, 0]],
                          fix_tree=True,
                          fix_edges=False,
                          weight=False)

    # Generate the SWC table
    swc, new_ids = make_swc(G_nx, coords=mesh.vertices, reindex=True)

    # Update vertex to node ID map
    mesh_map = np.array([new_ids.get(n, -1) for n in mesh_map])

    return Skeleton(swc=swc, mesh=mesh, mesh_map=mesh_map, method='teasar')
Example 16
 def agglomerate_cutout(self, img, timestamp=None, stop_layer=None):
   """Remap a graphene volume to its latest root ids. This creates a flat segmentation."""
   labels = fastremap.unique(img)
   roots = self.get_roots(labels, timestamp=timestamp, binary=True, stop_layer=stop_layer)
   mapping = { segid: root for segid, root in zip(labels, roots) }
   return fastremap.remap(img, mapping, preserve_missing_labels=True, in_place=True)
Example 17
def unique_sharded(
  requested_bbox, mip,
  meta, cache, lru, spec,
  compress, progress,
  fill_missing, background_color
):
  """
  Accumulate all unique labels within the requested
  bounding box.
  """
  full_bbox = requested_bbox.expand_to_chunk_size(
    meta.chunk_size(mip), offset=meta.voxel_offset(mip)
  )
  full_bbox = Bbox.clamp(full_bbox, meta.bounds(mip))
  core_bbox = requested_bbox.shrink_to_chunk_size(
    meta.chunk_size(mip), offset=meta.voxel_offset(mip)
  )
  core_bbox = Bbox.clamp(core_bbox, meta.bounds(mip))

  compress_cache = should_compress(meta.encoding(mip), compress, cache, iscache=True)

  chunk_size = meta.chunk_size(mip)
  grid_size = np.ceil(meta.bounds(mip).size3() / chunk_size).astype(np.uint32)

  reader = sharding.ShardReader(meta, cache, spec)
  bounds = meta.bounds(mip)

  all_gpts = list(gridpoints(full_bbox, bounds, chunk_size))
  core_gpts = list(gridpoints(core_bbox, bounds, chunk_size))

  code_map = {}
  all_morton_codes = compressed_morton_code(all_gpts, grid_size)
  for gridpoint, morton_code in zip(all_gpts, all_morton_codes):
    cutout_bbox = Bbox(
      bounds.minpt + gridpoint * chunk_size,
      min2(bounds.minpt + (gridpoint + 1) * chunk_size, bounds.maxpt)
    )
    code_map[morton_code] = cutout_bbox

  lru_codes = set([ code for code in all_morton_codes if code in lru ])
  lru_chunkdata = { code: lru[code] for code in lru_codes }
  
  core_morton_codes = set(compressed_morton_code(core_gpts, grid_size))
  io_core_morton_codes = core_morton_codes - lru_codes
  lru_core_morton_codes = core_morton_codes.intersection(lru_codes)

  def iterate_core():
    for mcs in sip(io_core_morton_codes, 10000):
      core_chunkdata = reader.get_data(mcs, meta.key(mip), progress=progress)
      for zcode, chunkdata in core_chunkdata.items():
        yield (zcode, chunkdata)
        lru[zcode] = chunkdata
    for code in lru_core_morton_codes:
      yield (code, lru_chunkdata[code])
      del lru_chunkdata[code]

  all_labels = set()
  for zcode, chunkdata in iterate_core():
    cutout_bbox = code_map[zcode]
    labels = decode_unique(
      meta, cutout_bbox, 
      chunkdata, fill_missing, mip,
      background_color=background_color
    )
    all_labels |= set(labels)

  shell_morton_codes = set(all_morton_codes) - set(core_morton_codes)
  io_shell_morton_codes = shell_morton_codes - lru_codes
  lru_shell_morton_codes = shell_morton_codes.intersection(lru_codes)

  def iterate_shell():
    shell_chunkdata = reader.get_data(io_shell_morton_codes, meta.key(mip), progress=progress)
    for zcode, chunkdata in shell_chunkdata.items():
      yield (zcode, chunkdata)
      lru[zcode] = chunkdata
    for code in lru_shell_morton_codes:
      yield (code, lru_chunkdata[code])
      del lru_chunkdata[code]    

  for zcode, chunkdata in iterate_shell():
    cutout_bbox = code_map[zcode]
    labels = decode(
      meta, cutout_bbox, 
      chunkdata, fill_missing, mip,
      background_color=background_color
    )
    crop_bbox = Bbox.intersection(requested_bbox, cutout_bbox)
    crop_bbox -= cutout_bbox.minpt
    labels = fastremap.unique(labels[ crop_bbox.to_slices() ])
    all_labels |= set(labels)

  return all_labels
Example 18
def edges_to_graph(edges,
                   nodes=None,
                   vertices=None,
                   fix_edges=True,
                   fix_tree=True,
                   drop_disconnected=False,
                   weight=True):
    """Create networkx Graph from edge list."""
    if not fix_tree and not fix_edges:
        # If no fixing required, we need to go for a directed graph straight
        # away
        G = nx.DiGraph()
    else:
        G = nx.Graph()

    if fix_edges:
        # Drop self-loops
        edges = edges[edges[:, 0] != edges[:, 1]]

        # Make sure we don't have a->b and b<-a edges
        edges = unique(np.sort(edges, axis=1), axis=0)

    # Extract nodes from edges if not explicitly provided
    if isinstance(nodes, type(None)):
        nodes = unique(edges.flatten())

    if isinstance(vertices, np.ndarray):
        coords = vertices[nodes]
        add = [(n, {
            'x': co[0],
            'y': co[1],
            'z': co[2]
        }) for n, co in zip(nodes, coords)]
    else:
        add = nodes
    G.add_nodes_from(add)

    G.add_edges_from([(e[0], e[1]) for e in edges])

    if fix_tree:
        # First remove cycles
        while True:
            try:
                # Find cycle
                cycle = nx.find_cycle(G)
            except nx.exception.NetworkXNoCycle:
                break
            except BaseException:
                raise

            # Sort by degree
            cycle = sorted(cycle, key=lambda x: G.degree[x[0]])

            # Remove the edge with the lowest degree
            G.remove_edge(cycle[0][0], cycle[0][1])

        # Now make sure this is a DAG, i.e. that all edges point in the same direction
        new_edges = []
        for c in nx.connected_components(G.to_undirected()):
            sg = nx.subgraph(G, c)
            # Pick a random root
            r = list(sg.nodes)[0]

            # Generate parent->child dictionary by graph traversal
            this_lop = nx.predecessor(sg, r)

            # Note: the root has no predecessor here, so no parent edge is added for it
            new_edges += [(k, v[0]) for k, v in this_lop.items() if v]

        # We need a directed Graph for this as otherwise the child -> parent
        # order in the edges might get lost
        G2 = nx.DiGraph()
        G2.add_nodes_from(G.nodes)
        G2.add_edges_from(new_edges)
        G = G2

    if drop_disconnected:
        # Array of degrees [[node_id, degree], [....], ...]
        deg = np.array(G.degree)
        G.remove_nodes_from(deg[deg[:, 1] == 0][:, 0])

    if weight and isinstance(vertices, np.ndarray):
        final_edges = np.array(G.edges)
        vec = vertices[final_edges[:, 0]] - vertices[final_edges[:, 1]]
        weights = np.sqrt(np.sum(vec**2, axis=1))
        G.remove_edges_from(list(G.edges))
        G.add_weighted_edges_from([(e[0], e[1], w)
                                   for e, w in zip(final_edges, weights)])

    return G
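The fix_edges clean-up only needs a sort plus a uniqueness pass; here with plain np.unique, which is what the module's unique wrapper is assumed to behave like:

import numpy as np

edges = np.array([[2, 1], [1, 2], [3, 3]])
edges = edges[edges[:, 0] != edges[:, 1]]          # drop self-loops
edges = np.unique(np.sort(edges, axis=1), axis=0)  # collapse a->b / b->a duplicates
print(edges)                                       # [[1 2]]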
Example 19
def skeletonize(all_labels,
                teasar_params=DEFAULT_TEASAR_PARAMS,
                anisotropy=(1, 1, 1),
                object_ids=None,
                dust_threshold=1000,
                cc_safety_factor=1,
                progress=False,
                fix_branching=True,
                in_place=False,
                fix_borders=True,
                parallel=1,
                parallel_chunk_size=100,
                extra_targets_before=[],
                extra_targets_after=[],
                fill_holes=False,
                fix_avocados=False):
    """
  Skeletonize all non-zero labels in a given 2D or 3D image.

  Required:
    all_labels: a 2D or 3D numpy array of integer type (signed or unsigned) 

  Optional:
    anisotropy: the physical dimensions of each axis (e.g. 4nm x 4nm x 40nm)
    object_ids: If not none, zero out all labels other than those specified here.
    teasar_params: {
      scale: during the "rolling ball" invalidation phase, multiply 
          the DBF value by this.
      const: during the "rolling ball" invalidation phase, this 
          is the minimum radius in chosen physical units (i.e. nm).
      soma_detection_threshold: if object has a DBF value larger than this, 
          root will be placed at largest DBF value and special one time invalidation
          will be run over that root location (see soma_invalidation scale)
          expressed in chosen physical units (i.e. nm) 
      pdrf_scale: scale factor in front of dbf, used to weight dbf over euclidean distance (higher to pay more attention to dbf) (default 5000)
      pdrf_exponent: exponent in dbf formula on distance from edge, faster if factor of 2 (default 16)
      soma_invalidation_scale: the 'scale' factor used in the one time soma root invalidation (default .5)
      soma_invalidation_const: the 'const' factor used in the one time soma root invalidation (default 0)
                             (units in chosen physical units (i.e. nm))
      max_paths: max paths to trace on a single object. Moves onto the next object after this point.
    }
    dust_threshold: don't bother skeletonizing connected components smaller than
      this many voxels.
    fill_holes: preemptively run a void filling algorithm on all connected
      components and delete labels that get filled in. This can improve the
      quality of the reconstruction if holes in the shapes are artifacts introduced
      by the segmentation pipeline. This option incurs moderate overhead.

      WARNING: THIS WILL REMOVE INPUT LABELS THAT ARE DEEMED TO BE HOLES.

    cc_safety_factor: Value between 0 and 1 that scales the size of the 
      disjoint set maps in connected_components. 1 is guaranteed to work,
      but is probably excessive and corresponds to every pixel being a different
      label. Use smaller values to save some memory.

    extra_targets_before: List of x,y,z voxel coordinates that will all 
      be traced to from the root regardless of whether those points have 
      been invalidated. These targets will be applied BEFORE the regular
      target selection algorithm is run.      

      e.g. [ (x,y,z), (x,y,z) ]

    extra_targets_after: Same as extra_targets_before but the additional
      targets will be applied AFTER the usual algorithm runs.

    progress: if true, display a progress bar
    fix_branching: When enabled, zero the edge weights of previously 
      traced paths. This causes branch points to occur closer to 
      the actual path divergence. However, there is a performance penalty
      associated with this as dijkstra's algorithm is computed once per path
      rather than once per skeleton.
    in_place: if true, allow input labels to be modified to reduce
      memory usage and possibly improve performance.
    fix_borders: ensure that segments touching the border place a 
      skeleton endpoint in a predictable place to make merging 
      adjacent chunks easier.
    fix_avocados: If nuclei are segmented separately from somata
      then we can try to detect and fix this issue.
    parallel: number of subprocesses to use.
      <= 0: Use multiprocessing.cpu_count()
         1: Only use the main process.
      >= 2: Use this number of subprocesses.
    parallel_chunk_size: default number of skeletons to 
      submit to each parallel process before returning results,
      updating the progress bar, and submitting a new task set. 
      Setting this number too low results in excess IPC overhead,
      and setting it too high can result in task starvation towards
      the end of a job and infrequent progress bar updates. If the
      chunk size is set higher than num tasks // parallel, that number
      is used instead.

  Returns: { $segid: cloudvolume.PrecomputedSkeleton, ... }
  """

    anisotropy = np.array(anisotropy, dtype=np.float32)

    all_labels = format_labels(all_labels, in_place=in_place)
    all_labels = apply_object_mask(all_labels, object_ids)

    if all_labels.size <= dust_threshold:
        return {}

    minlabel, maxlabel = fastremap.minmax(all_labels)

    if minlabel == 0 and maxlabel == 0:
        return {}

    cc_labels, remapping = compute_cc_labels(all_labels, cc_safety_factor)
    del all_labels

    if fill_holes:
        cc_labels = fill_all_holes(cc_labels, progress)

    extra_targets_before = points_to_labels(extra_targets_before, cc_labels)
    extra_targets_after = points_to_labels(extra_targets_after, cc_labels)

    def edtfn(labels):
        return edt.edt(
            labels,
            anisotropy=anisotropy,
            black_border=(minlabel == maxlabel),
            order='F',
            parallel=parallel,
        )

    all_dbf = edtfn(cc_labels)

    if fix_avocados:
        cc_labels, all_dbf, remapping = engage_avocado_protection(
            cc_labels,
            all_dbf,
            remapping,
            soma_detection_threshold=teasar_params.get(
                'soma_detection_threshold', 0),
            edtfn=edtfn,
            progress=progress,
        )

    cc_segids, pxct = fastremap.unique(cc_labels, return_counts=True)
    cc_segids = [
        sid for sid, ct in zip(cc_segids, pxct)
        if ct > dust_threshold and sid != 0
    ]

    all_slices = find_objects(cc_labels)

    border_targets = defaultdict(list)
    if fix_borders:
        border_targets = compute_border_targets(cc_labels, anisotropy)

    print_quotes(parallel)  # easter egg

    if parallel <= 0:
        parallel = mp.cpu_count()

    if parallel == 1:
        return skeletonize_subset(all_dbf, cc_labels, remapping, teasar_params,
                                  anisotropy, all_slices, border_targets,
                                  extra_targets_before, extra_targets_after,
                                  progress, fix_borders, fix_branching,
                                  cc_segids)
    else:
        # The following section can't be moved into
        # skeletonize parallel because then all_dbf
        # and cc_labels can't be deleted to save memory.
        suffix = uuid.uuid1().hex

        dbf_shm_location = 'kimimaro-shm-dbf-' + suffix
        cc_shm_location = 'kimimaro-shm-cc-labels-' + suffix

        dbf_mmap, all_dbf_shm = shm.ndarray(all_dbf.shape,
                                            all_dbf.dtype,
                                            dbf_shm_location,
                                            order='F')
        cc_mmap, cc_labels_shm = shm.ndarray(cc_labels.shape,
                                             cc_labels.dtype,
                                             cc_shm_location,
                                             order='F')
        all_dbf_shm[:] = all_dbf
        cc_labels_shm[:] = cc_labels
        del all_dbf
        del cc_labels

        skeletons = skeletonize_parallel(
            all_dbf_shm, dbf_shm_location, cc_labels_shm, cc_shm_location,
            remapping, teasar_params, anisotropy, all_slices, border_targets,
            extra_targets_before, extra_targets_after, progress, fix_borders,
            fix_branching, cc_segids, parallel, parallel_chunk_size)

        dbf_mmap.close()
        cc_mmap.close()

        return skeletons
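The early exits rely on fastremap.minmax, which returns the smallest and largest label in a single pass; a toy check (values invented):

import numpy as np
import fastremap

all_labels = np.array([[0, 0], [2, 5]], dtype=np.uint32)
minlabel, maxlabel = fastremap.minmax(all_labels)
print(minlabel, maxlabel)   # 0 5
# An all-zero volume (minlabel == maxlabel == 0) would be skipped entirely.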
Example 20
def get_masks(p, iscell=None, rpad=20, flows=None, threshold=0.4, use_gpu=False, device=None):
    """ create masks using pixel convergence after running dynamics
    
    Makes a histogram of final pixel locations p, initializes masks 
    at peaks of histogram and extends the masks from the peaks so that
    they include all pixels with more than 2 final pixels p. Discards 
    masks with flow errors greater than the threshold. 
    Parameters
    ----------------
    p: float32, 3D or 4D array
        final locations of each pixel after dynamics,
        size [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
    iscell: bool, 2D or 3D array
        if iscell is not None, set pixels that are 
        iscell False to stay in their original location.
    rpad: int (optional, default 20)
        histogram edge padding
    threshold: float (optional, default 0.4)
        masks with flow error greater than threshold are discarded 
        (if flows is not None)
    flows: float, 3D or 4D array (optional, default None)
        flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx]. If flows
        is not None, then masks with inconsistent flows are removed using 
        `remove_bad_flow_masks`.
    Returns
    ---------------
    M0: int, 2D or 3D array
        masks with inconsistent flow masks removed, 
        0=NO masks; 1,2,...=mask labels,
        size [Ly x Lx] or [Lz x Ly x Lx]
    
    """
    
    pflows = []
    edges = []
    shape0 = p.shape[1:]
    dims = len(p)
    if iscell is not None:
        if dims==3:
            inds = np.meshgrid(np.arange(shape0[0]), np.arange(shape0[1]),
                np.arange(shape0[2]), indexing='ij')
        elif dims==2:
            inds = np.meshgrid(np.arange(shape0[0]), np.arange(shape0[1]),
                     indexing='ij')
        for i in range(dims):
            p[i, ~iscell] = inds[i][~iscell]

    for i in range(dims):
        pflows.append(p[i].flatten().astype('int32'))
        edges.append(np.arange(-.5-rpad, shape0[i]+.5+rpad, 1))

    h,_ = np.histogramdd(tuple(pflows), bins=edges)
    hmax = h.copy()
    for i in range(dims):
        hmax = maximum_filter1d(hmax, 5, axis=i)

    seeds = np.nonzero(np.logical_and(h-hmax>-1e-6, h>10))
    Nmax = h[seeds]
    isort = np.argsort(Nmax)[::-1]
    seeds = tuple(s[isort] for s in seeds)  # order seeds by descending peak height

    pix = list(np.array(seeds).T)

    shape = h.shape
    if dims==3:
        expand = np.nonzero(np.ones((3,3,3)))
    else:
        expand = np.nonzero(np.ones((3,3)))
    for e in expand:
        e = np.expand_dims(e,1)

    for iter in range(5):
        for k in range(len(pix)):
            if iter==0:
                pix[k] = list(pix[k])
            newpix = []
            iin = []
            for i,e in enumerate(expand):
                epix = e[:,np.newaxis] + np.expand_dims(pix[k][i], 0) - 1
                epix = epix.flatten()
                iin.append(np.logical_and(epix>=0, epix<shape[i]))
                newpix.append(epix)
            iin = np.all(tuple(iin), axis=0)
            newpix = [px[iin] for px in newpix]  # keep only in-bounds neighbors
            newpix = tuple(newpix)
            igood = h[newpix]>2
            for i in range(dims):
                pix[k][i] = newpix[i][igood]
            if iter==4:
                pix[k] = tuple(pix[k])
    
    M = np.zeros(h.shape, np.uint32)
    for k in range(len(pix)):
        M[pix[k]] = 1+k
        
    for i in range(dims):
        pflows[i] = pflows[i] + rpad
    M0 = M[tuple(pflows)]

    # remove big masks
    uniq, counts = fastremap.unique(M0, return_counts=True)
    big = np.prod(shape0) * 0.4
    bigc = uniq[counts > big]
    if len(bigc) > 0 and (len(bigc)>1 or bigc[0]!=0):
        M0 = fastremap.mask(M0, bigc)
    fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
    M0 = np.reshape(M0, shape0)
    return M0
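fastremap.renumber at the end guarantees contiguous labels; a toy example of the relabeling (by default 0 is preserved and the remaining labels are renumbered from 1):

import numpy as np
import fastremap

M0 = np.array([0, 3, 3, 7, 7, 7], dtype=np.uint32)
M0, remapping = fastremap.renumber(M0, in_place=True)
print(M0)          # [0 1 1 2 2 2]
print(remapping)   # e.g. {0: 0, 3: 1, 7: 2}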
Example 21
def make_swc(x, coords, reindex=True, validate=True):
    """Generate SWC table.

    Parameters
    ----------
    x :         numpy.ndarray | networkx.Graph | networkx.DiGraph
                Data to generate SWC from. Can be::

                    - (N, 2) array of child->parent edges
                    - networkX graph

    coords :    trimesh.Trimesh | np.ndarray of vertices
                Coordinates of nodes in ``x``.
    reindex :   bool
                If True, will re-index node IDs such that parent nodes always
                have a lower node ID than their children. This is a requirement
                for the SWC format. Will also return a dictionary mapping
                original to re-indexed node IDs.
    validate :  bool
                If True, will check if the SWC table is valid and raise an
                exception if issues are found.

    Returns
    -------
    swc :       pandas.DataFrame
    new_ids :   dict
                If ``reindex=True`` will also return a map for original to
                re-indexed node IDs.

    """
    assert isinstance(coords, (tm.Trimesh, np.ndarray))

    if isinstance(x, np.ndarray):
        edges = x
    elif isinstance(x, (nx.Graph, nx.DiGraph)):
        edges = np.array(x.edges)
    else:
        raise TypeError(f'Expected array or Graph, got "{type(x)}"')

    # Make sure edges are unique
    edges = unique(edges, axis=0)

    # Need to convert to None if empty - otherwise DataFrame creation acts up
    if len(edges) == 0:
        edges = None

    # Generate node table (do NOT remove the explicit dtype)
    swc = pd.DataFrame(edges, columns=['node_id', 'parent_id'], dtype=int)

    # See if we need to manually add rows for root node(s)
    miss = swc.parent_id.unique()
    miss = miss[~np.isin(miss, swc.node_id.values)]
    miss = miss[miss > -1]

    # Must not use any() here because we might get miss=[0]
    if len(miss):
        roots = pd.DataFrame([[n, -1] for n in miss], columns=swc.columns)
        swc = pd.concat([swc, roots], axis=0)

    # See if we need to add any disconnected nodes
    if isinstance(x, (nx.Graph, nx.DiGraph)):
        miss = set(x.nodes) - set(swc.node_id.values) - set([-1])
        if miss:
            disc = pd.DataFrame([[n, -1] for n in miss], columns=swc.columns)
            swc = pd.concat([swc, disc], axis=0)

    if isinstance(coords, tm.Trimesh):
        coords = coords.vertices

    if not swc.empty:
        # Map x/y/z coordinates
        swc['x'] = coords[swc.node_id, 0]
        swc['y'] = coords[swc.node_id, 1]
        swc['z'] = coords[swc.node_id, 2]
    else:
        swc['x'] = swc['y'] = swc['z'] = None

    # Placeholder radius
    swc['radius'] = None

    if reindex:
        _, new_ids = reindex_swc(swc, inplace=True)
    else:
        swc = swc.sort_values('parent_id').reset_index(drop=True)

    if validate:
        # Check if any node has multiple parents
        if any(swc.node_id.duplicated()):
            raise ValueError('Nodes with multiple parents found.')

    if reindex:
        return swc, new_ids

    return swc
Example 22
def edges_to_graph(edges,
                   nodes=None,
                   vertices=None,
                   fix_edges=True,
                   fix_tree=True,
                   drop_disconnected=False,
                   weight=True,
                   radii=None):
    """Create networkx Graph from edge list.

    Parameters
    ----------
    edges :         (N, 2) array
    nodes :         (M, ) array, optional
                    Node IDs. Should be provided in case of isolated nodes not
                    part of the edge list.
    vertices :      (M, 3) array, optional
                    X/Y/Z locations of nodes.
    fix_edges :     bool
                    If True (recommended!) will drop self-loops and remove
                    recurrent edges.
    fix_tree :      bool | "length" | "radius" | "degree"
                    If not False (recommended!) will fix the tree by removing
                    cycles. This is done using a minimum-spanning-tree or a
                    breadth-first search (see below). To improve this we can use
                    weights to increase the probability that cuts are made at
                    the right edges (i.e. preserving the "correct" topology of
                    the skeleton):

                      - "length" prioritizes cutting at long edges (requires
                        node positions as `vertex` to be provided)
                      - "radius" prioritizes cutting at edges with small radius
                        (requires `radii` to be provided)
                      - "degree" (default for `True`) prioritizes cutting at
                        edges between nodes with low degree (i.e. non branch points)
                      - `True` will simply use a breadth-first search to
                        produce a directed graph without cycles.

    drop_disconnected : bool
                    Drops disconnected nodes from graph. Not recommended since
                    it breaks the vertex -> node mapping.
    weight :        bool
                    Whether to add edge lengths as weight to the final graph.
                    Requires `vertices` to be provided.
    radii :         (M, ) array, optional
                    Radii for each node. Only relevant if `fix_tree='radius'`.

    Returns
    -------
    G
                    networkx.DiGraph if `fix_tree` or networkx.Graph if not.

    """
    if fix_edges:
        # Drop self-loops
        edges = edges[edges[:, 0] != edges[:, 1]]

        # Make sure we don't have a->b and b<-a edges
        edges = unique(np.sort(edges, axis=1), axis=0)

    # Extract nodes from edges if not explicitly provided
    if isinstance(nodes, type(None)):
        nodes = unique(edges.flatten())

    # Start with undirected graph
    G = nx.Graph()
    G.add_nodes_from(nodes)
    G.add_edges_from(edges)

    # Fix tree with MST if specified
    if isinstance(fix_tree, str):
        if fix_tree == 'radius':
            # Try cutting at edges with small radii
            if isinstance(radii, type(None)):
                raise ValueError(
                    'Must provide `radii` with `fix_tree="radius"`')
            weights = 1 / np.vstack(
                (radii[edges[:, 0]], radii[edges[:, 1]])).mean(axis=0)
        elif fix_tree == 'length':
            # Try cutting at long edges
            if isinstance(vertices, type(None)):
                raise ValueError(
                    'Must provide `vertices` with `fix_tree="length"`')
            vec = vertices[np.array(G.edges)[:, 0]] - vertices[np.array(
                G.edges)[:, 1]]
            weights = np.sqrt(np.sum(vec**2, axis=1))
        elif fix_tree == 'degree':
            # Else try cutting at edges with lowest degree
            weights = 1 / np.array(
                [max(G.degree[e[0]], G.degree[e[1]]) for e in edges])
        else:
            raise ValueError(f'Unknown mode for `fix_tree`: "{fix_tree}"')
        # Note we are inverting the weights so that we cut either at a low
        # radius or a low degree
        weights[weights <= 0] = weights[weights > 0].min() / 2
        nx.set_edge_attributes(G,
                               dict(zip([tuple(e) for e in edges], weights)),
                               name='weight')

        # Get the minimum spanning tree
        G = nx.minimum_spanning_tree(G, weight='weight')

    # Even if we already ran the MST, we still need to orient the tree
    # This by itself also "fixes"  the tree (i.e. breaks cycles) but it doesn't
    # give us much control over it
    if fix_tree:
        trees = []
        for cc in nx.connected_components(G):
            # Get subgraph of this component
            SG = nx.subgraph(G, cc)
            # Create an oriented tree
            trees.append(nx.bfs_tree(SG, source=list(SG.nodes)[0]))

        # Create the union of all trees
        if len(trees) > 1:
            # For some reason this is much faster than nx.compose
            G = nx.DiGraph()
            for t in trees:
                G.add_edges_from(list(t.edges))
                G.add_nodes_from(list(t.nodes))
        else:
            G = trees[0]

        # Reverse to child -> parent
        G = G.reverse()

    if drop_disconnected:
        # Array of degrees [[node_id, degree], [....], ...]
        deg = np.array(G.degree)
        G.remove_nodes_from(deg[deg[:, 1] == 0][:, 0])

    if weight and isinstance(vertices, np.ndarray):
        final_edges = np.array(G.edges)
        vec = vertices[final_edges[:, 0]] - vertices[final_edges[:, 1]]
        weights = np.sqrt(np.sum(vec**2, axis=1))
        G.remove_edges_from(list(G.edges))
        G.add_weighted_edges_from([(e[0], e[1], w)
                                   for e, w in zip(final_edges, weights)])

    return G
Example 23
def mst_over_mesh(mesh, verts, limit='auto'):
    """Generate minimum spanning tree by subsetting mesh to given vertices.

    Will (re-)connect vertices based on geodesic distance in original mesh
    using a minimum spanning tree.

    Parameters
    ----------
    mesh :      trimesh.Trimesh
                Mesh to subset.
    verts :     iterable
                Vertex indices to keep for the tree.
    limit :     float | np.inf | "auto"
                Use this to limit the distance for shortest path search
                (``scipy.sparse.csgraph.dijkstra``). Can greatly speed up this
                function at the risk of producing disconnected components. By
                default (auto), we are using 3x the max observed Euclidean
                distance between ``verts``.

    Returns
    -------
    edges :     np.ndarray
                List of `node` -> `parent` edges. Note that these edges are
                already hierarchical, i.e. each node has exactly 1 parent
                except for the root node(s) which has parent ``-1``.

    """
    # Make sure vertices to keep are unique
    keep = unique(verts)

    # Get some shorthands
    verts = mesh.vertices
    edges = mesh.edges_unique
    edge_lengths = mesh.edges_unique_length

    # Produce adjacency matrix from edges and edge lengths
    adj = scipy.sparse.coo_matrix((edge_lengths, (edges[:, 0], edges[:, 1])),
                                  shape=(verts.shape[0], verts.shape[0]))

    if limit == 'auto':
        distances = scipy.spatial.distance.pdist(verts[keep])
        limit = np.max(distances) * 3

    # Get geodesic distances between vertices
    dist_matrix = scipy.sparse.csgraph.dijkstra(csgraph=adj,
                                                directed=False,
                                                indices=keep,
                                                limit=limit)

    # Subset along second axis
    dist_matrix = dist_matrix[:, keep]

    # Get minimum spanning tree
    mst = scipy.sparse.csgraph.minimum_spanning_tree(dist_matrix,
                                                     overwrite=True)

    # Turn into COO matrix
    coo = mst.tocoo()

    # Extract edge list
    edges = np.array([keep[coo.row], keep[coo.col]]).T

    # Last but not least we have to run a depth first search to turn this
    # into a hierarchical tree, i.e. make sure edges are oriented in a way that
    # each node only has a single parent (turn a<-b->c into a->b->c)

    # Generate and populate undirected graph
    G = nx.Graph()
    G.add_edges_from(edges)

    # Generate list of parents
    edges = []
    # Go over all connected components
    for c in nx.connected_components(G):
        # Get subgraph of this connected component
        SG = nx.subgraph(G, c)

        # Use first node as root
        r = list(SG.nodes)[0]

        # List of parents: {node: [parent], root: []}
        this_lop = nx.predecessor(SG, r)

        # Note that we assign -1 as root's parent
        edges += [(k, v[0] if v else -1) for k, v in this_lop.items()]

    return np.array(edges).astype(int)
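The heart of the function is the geodesic distance matrix fed into minimum_spanning_tree; a tiny standalone sketch with a made-up 3x3 distance matrix:

import numpy as np
import scipy.sparse
import scipy.sparse.csgraph

# Pairwise distances between three kept vertices (toy values).
dist_matrix = scipy.sparse.csr_matrix(np.array([[0.0, 1.0, 4.0],
                                                [1.0, 0.0, 2.0],
                                                [4.0, 2.0, 0.0]]))
mst = scipy.sparse.csgraph.minimum_spanning_tree(dist_matrix)
coo = mst.tocoo()
print(np.array([coo.row, coo.col]).T)   # [[0 1] [1 2]] -- the two cheapest connections survive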