def __init__(self):
    """Build an edgeless graph of three nodes inside a 1500^3 ROI starting at -500."""
    locations = ([1, 1, 1], [500, 500, 500], [550, 550, 550])
    # Node ids are 1-based and follow the order of ``locations``.
    nodes = [Node(index + 1, np.array(loc)) for index, loc in enumerate(locations)]
    bounds = Roi((-500, -500, -500), (1500, 1500, 1500))
    self.graph = Graph(nodes, [], GraphSpec(roi=bounds))
def __init__(self):
    """Create a one-node graph whose ROI starts at -200 on every axis."""
    only_node = Node(id=1, location=np.array([50, 70, 100]))
    bounds = Roi((-200, -200, -200), (400, 400, 478))
    self.graph = Graph([only_node], [], GraphSpec(roi=bounds))
def __init__(self):
    """Create a one-node graph with the node at the centre of a 1000^3 ROI at the origin."""
    centre_node = Node(id=1, location=np.array([500, 500, 500]))
    bounds = Roi((0, 0, 0), (1000, 1000, 1000))
    self.graph = Graph([centre_node], [], GraphSpec(roi=bounds))
def test_crop(self):
    """Cropping yields a sub-graph whose spec is independent of the original's."""
    crop_roi = Roi(Coordinate([1, 1, 1]), Coordinate([3, 3, 3]))
    full = Graph(self.nodes, self.edges, self.spec)
    cropped = full.crop(crop_roi)

    # The source graph keeps its ROI; the crop reports the requested ROI.
    self.assertEqual(full.spec.roi, self.spec.roi)
    self.assertEqual(cropped.spec.roi, crop_roi)

    # Flipping ``directed`` on the crop must not leak back into the source.
    cropped.spec.directed = False
    self.assertTrue(full.spec.directed)
    self.assertFalse(cropped.spec.directed)
def __init__(self):
    """Set up a three-node chain graph (1 - 2 - 3) with float node locations."""
    self.dtype = float
    coordinates = {1: [1, 1, 1], 2: [500, 500, 500], 3: [550, 550, 550]}
    self.__vertices = [
        Node(id=node_id, location=np.array(xyz, dtype=self.dtype))
        for node_id, xyz in coordinates.items()
    ]
    self.__edges = [Edge(src, dst) for src, dst in ((1, 2), (2, 3))]
    begin = Coordinate([-500] * 3)
    extent = Coordinate([1500] * 3)
    self.__spec = GraphSpec(roi=Roi(begin, extent))
    self.graph = Graph(self.__vertices, self.__edges, self.__spec)
def setup(self):
    """Declare every graph key in ``self.points`` and build a straight-line graph.

    Nodes ``0 .. num_points - 1`` sit on the diagonal of the ROI
    ``[0, size)``, spaced by ``min(size) / num_points``.  Consecutive nodes are
    joined by an edge.
    """
    dims = len(self.size)
    roi = Roi(Coordinate([0] * dims), self.size)
    for points_key in self.points:
        self.provides(points_key, GraphSpec(roi=roi, directed=self.directed))

    k = min(self.size)
    nodes = [
        # Match the ROI's dimensionality instead of hard-coding 3, so node
        # locations always have the same rank as the spec they live in.
        Node(id=i, location=np.array([i * k / self.num_points] * dims))
        for i in range(self.num_points)
    ]
    edges = [Edge(i, i + 1) for i in range(self.num_points - 1)]
    self.graph = Graph(nodes, edges, GraphSpec(roi=roi, directed=self.directed))
def provide(self, request):
    """Return a label volume with each point's id painted in, plus the points in the graph ROI."""
    graph_roi = request[GraphKeys.TEST_POINTS].roi
    array_roi = request[ArrayKeys.TEST_LABELS].roi
    labels_spec = self.spec[ArrayKeys.TEST_LABELS]

    voxel_roi = array_roi // labels_spec.voxel_size
    labels = np.zeros(voxel_roi.get_shape(), dtype=np.uint32)
    # Background pattern: every other index along axis 1 is set to 100.
    labels[:, ::2] = 100
    for point in self.points:
        # ``point_to_voxel`` maps the point into ``labels`` index space.
        labels[self.point_to_voxel(array_roi, point.location)] = point.id

    out_spec = labels_spec.copy()
    out_spec.roi = array_roi

    batch = Batch()
    batch.arrays[ArrayKeys.TEST_LABELS] = Array(labels, spec=out_spec)
    inside = [point for point in self.points if graph_roi.contains(point.location)]
    batch.graphs[GraphKeys.TEST_POINTS] = Graph(inside, [], GraphSpec(roi=graph_roi))
    return batch
class TestSourceRandomLocation(BatchProvider):
    """Serve a fixed three-node graph, cropped and trimmed to each request."""

    def __init__(self):
        node_locations = [
            (1, [1, 1, 1]),
            (2, [500, 500, 500]),
            (3, [550, 550, 550]),
        ]
        bounds = Roi((0, 0, 0), (1000, 1000, 1000))
        self.graph = Graph(
            [Node(id=nid, location=np.array(loc)) for nid, loc in node_locations],
            [],
            GraphSpec(roi=bounds),
        )

    def setup(self):
        self.provides(GraphKeys.TEST_GRAPH, self.graph.spec)

    def provide(self, request):
        requested_roi = request[GraphKeys.TEST_GRAPH].roi
        result = Batch()
        result[GraphKeys.TEST_GRAPH] = self.graph.crop(requested_roi).trim(requested_roi)
        return result
class TestPointSource(BatchProvider):
    """Provide a straight-line test graph under each of the given graph keys.

    Args:
        points: graph keys to provide; all share the same underlying graph.
        directed: whether the provided graphs are directed.
        size: shape of the ROI ``[0, size)``; its length sets the dimensionality.
        num_points: number of nodes placed along the ROI diagonal.
    """

    def __init__(self, points: List[GraphKey], directed: bool, size: Coordinate, num_points: int):
        self.points = points
        self.directed = directed
        self.size = size
        self.num_points = num_points

    def setup(self):
        dims = len(self.size)
        roi = Roi(Coordinate([0] * dims), self.size)
        for points_key in self.points:
            self.provides(points_key, GraphSpec(roi=roi, directed=self.directed))

        k = min(self.size)
        nodes = [
            # Use the ROI's dimensionality rather than a hard-coded 3 so node
            # locations always match the spec they are declared in.
            Node(id=i, location=np.array([i * k / self.num_points] * dims))
            for i in range(self.num_points)
        ]
        edges = [Edge(i, i + 1) for i in range(self.num_points - 1)]
        self.graph = Graph(nodes, edges, GraphSpec(roi=roi, directed=self.directed))

    def provide(self, request: BatchRequest) -> Batch:
        batch = Batch()
        for points_key in self.points:
            if points_key not in request:
                continue
            subgraph = self.graph.crop(roi=request[points_key].roi)
            subgraph.relabel_connected_components()
            batch[points_key] = subgraph
        return batch
def test_nodes():
    """Node locations can be reassigned through ``Graph.nodes`` and read back."""
    initial_locations = {
        1: np.array([1, 1, 1], dtype=np.float32),
        2: np.array([500, 500, 500], dtype=np.float32),
        3: np.array([550, 550, 550], dtype=np.float32),
    }
    replacement_locations = {
        1: np.array([0, 0, 0], dtype=np.float32),
        2: np.array([50, 50, 50], dtype=np.float32),
        3: np.array([55, 55, 55], dtype=np.float32),
    }
    # ``node_id`` instead of ``id`` so the builtin is not shadowed.
    nodes = [
        Node(id=node_id, location=location)
        for node_id, location in initial_locations.items()
    ]
    edges = [Edge(1, 2), Edge(2, 3)]
    spec = GraphSpec(roi=Roi(Coordinate([-500, -500, -500]), Coordinate([1500, 1500, 1500])))
    graph = Graph(nodes, edges, spec)

    for node in graph.nodes:
        node.location = replacement_locations[node.id]

    for node in graph.nodes:
        assert np.allclose(node.location, replacement_locations[node.id])
class ExampleGraphSource(BatchProvider):
    """Provide a three-node chain graph, cropped (not trimmed) to each request."""

    def __init__(self):
        self.dtype = float
        self.__vertices = [
            Node(id=vid, location=np.array([c, c, c], dtype=self.dtype))
            for vid, c in ((1, 1), (2, 500), (3, 550))
        ]
        self.__edges = [Edge(1, 2), Edge(2, 3)]
        bounds = Roi(Coordinate([-500, -500, -500]), Coordinate([1500, 1500, 1500]))
        self.__spec = GraphSpec(roi=bounds)
        self.graph = Graph(self.__vertices, self.__edges, self.__spec)

    def setup(self):
        self.provides(GraphKeys.TEST_GRAPH, self.__spec)

    def provide(self, request):
        wanted = request[GraphKeys.TEST_GRAPH].roi
        out = Batch()
        out[GraphKeys.TEST_GRAPH] = self.graph.crop(wanted)
        return out
def process(self, batch, request):
    """Cut the edges of every branch point down to those touching its lowest-id neighbour.

    Branch points (degree > 2) are collected once, before any edge is removed,
    and then processed in that order.  Undirected: every neighbour except the
    smallest id is disconnected.  Directed: the smallest id over successors and
    predecessors is kept, and edges in either direction to any other neighbour
    are removed.
    """
    source = batch[self.graph]
    nx_graph = source.to_nx_graph()
    branch_nodes = [node for node in nx_graph.nodes if nx_graph.degree(node) > 2]

    for node in branch_nodes:
        if not nx_graph.is_directed():
            ordered = sorted(nx_graph.neighbors(node))
            for other in ordered[1:]:
                nx_graph.remove_edge(node, other)
            continue

        out_nodes = list(nx_graph.successors(node))
        in_nodes = list(nx_graph.predecessors(node))
        keep = min(out_nodes + in_nodes)
        for other in out_nodes:
            if other != keep:
                nx_graph.remove_edge(node, other)
        for other in in_nodes:
            if other != keep:
                nx_graph.remove_edge(other, node)

    result = Batch()
    result[self.graph] = Graph.from_nx_graph(nx_graph, source.spec.copy())
    return result
def test_shift_points5(self):
    """Shifting along axis 1 keeps points 0, 2 and 4, with 0 and 4 moved to x=6."""
    locations = [[3, 0], [3, 2], [3, 4], [3, 6], [3, 8]]
    source_nodes = [Node(id=i, location=np.array(loc)) for i, loc in enumerate(locations)]
    points = Graph(source_nodes, [], GraphSpec(Roi(offset=(0, 0), shape=(15, 10))))
    request_roi = Roi(offset=(3, 0), shape=(9, 10))
    shift_array = np.array([[3, 0], [-3, 0], [0, 0], [-3, 0], [3, 0]], dtype=int)

    expected = [
        Node(id=0, location=np.array([6, 0])),
        Node(id=2, location=np.array([3, 4])),
        Node(id=4, location=np.array([6, 8])),
    ]
    result = ShiftAugment.shift_points(
        points,
        request_roi,
        shift_array,
        shift_axis=1,
        lcm_voxel_size=Coordinate((3, 2)),
    )

    self.assertTrue(self.points_equal(result.nodes, expected))
    self.assertTrue(result.spec == GraphSpec(request_roi))
def _empty_copy(self, base: Batch):
    """Return a new batch mirroring ``base`` with zeroed arrays and empty graphs.

    Arrays keep the shape and dtype of their data (via ``np.zeros_like``).
    Graphs keep their spec but contain no nodes or edges.  Specs are
    deep-copied so the result never aliases ``base``.
    """
    add = Batch()
    for key, array in base.arrays.items():
        add[key] = Array(np.zeros_like(array.data), spec=copy.deepcopy(array.spec))
    # Graphs are read from ``Batch.graphs``, where every provider in this file
    # stores them.  NOTE(review): the previous ``base.points`` may not exist on
    # a Graph-based Batch — confirm.
    for key, graph in base.graphs.items():
        add[key] = Graph([], [], spec=copy.deepcopy(graph.spec))
    return add
def provide(self, request):
    """Return two chains (0-1-2 and 3-104-5) on the diagonal, cropped to the request."""
    # node id -> coordinate shared by all three axes
    diagonal_positions = {0: 1, 1: 10, 2: 19, 3: 21, 104: 30, 5: 39}
    nodes = [
        Node(id=nid, location=np.array((p, p, p)))
        for nid, p in diagonal_positions.items()
    ]
    edges = [Edge(u, v) for u, v in ((0, 1), (1, 2), (3, 104), (104, 5))]

    graph_spec = self.spec[GraphKeys.RAW].copy()
    graph_spec.roi = request[GraphKeys.RAW].roi

    batch = Batch()
    batch[GraphKeys.RAW] = Graph(nodes, edges, graph_spec).crop(graph_spec.roi)
    return batch
def process(self, batch, request):
    """Tag nodes in the write ROI with a component id and the component's cable length.

    The write ROI is the graph's ROI shrunk by ``self.context`` on both sides.
    Consider each connected component (weakly connected if directed) with at
    least one node in the write ROI.  Each such contained node gets:

    * ``self.component_attr``: the smallest id among the contained nodes.
    * ``self.size_attr``: the component's edge length inside the write ROI.
      Edges with both ends inside count fully; edges with one end inside
      count half.
    """
    source = batch[self.graph]
    nx_graph = source.to_nx_graph()
    assert source.spec.roi.get_shape() == self.read_size
    logger.debug(
        f"{self.name()} got graph with {nx_graph.number_of_nodes()} nodes, and "
        f"{nx_graph.number_of_edges()} edges!")

    write_roi = source.spec.roi.grow(-self.context, -self.context)
    if nx_graph.is_directed():
        find_components = nx.weakly_connected_components
    else:
        find_components = nx.connected_components

    for component in find_components(nx_graph):
        inside = [
            n for n in component
            if write_roi.contains(nx_graph.nodes[n]["location"])
        ]
        if not inside:
            continue

        component_id = min(inside)
        sub = nx_graph.subgraph(component)
        cable = 0
        for u, v in sub.edges:
            u_loc = sub.nodes[u]["location"]
            v_loc = sub.nodes[v]["location"]
            length = np.linalg.norm(u_loc - v_loc)
            u_in = write_roi.contains(u_loc)
            v_in = write_roi.contains(v_loc)
            if u_in and v_in:
                cable += length
            elif u_in or v_in:
                cable += length / 2

        # Subgraph views share node attribute dicts with ``nx_graph``.
        for n in inside:
            node_attrs = sub.nodes[n]
            node_attrs[self.component_attr] = int(component_id)
            node_attrs[self.size_attr] = float(cable)

    updated = 0
    for _, node_attrs in nx_graph.nodes(data=True):
        if write_roi.contains(node_attrs["location"]):
            assert self.component_attr in node_attrs
            updated += 1
    logger.debug(
        f"{self.name()} updated component id of {updated} nodes in write_roi"
    )

    result = Batch()
    result[self.graph] = Graph.from_nx_graph(nx_graph, source.spec.copy())
    return result
def _shift_and_crop(self, points: Graph, array: Array, direction: Coordinate, output_roi: Roi):
    """Shift ``array`` and ``points`` by ``direction`` and re-anchor them at ``output_roi``.

    A window of ``output_roi``'s shape is cut out of ``array``.  The window is
    centred on the array's centre plus ``direction``.  Points inside that window
    are translated so the window offset maps onto ``output_roi``'s begin.
    A point-to-parent edge is kept when the parent also lies in the window.
    ``label_id`` is reassigned per weakly connected component.

    NOTE(review): this body still uses the legacy Points-style API
    (``points.data``, ``point_id``/``parent_id``, ``Node(location, ...)`` and a
    two-argument ``Graph(data, spec)``), unlike the ``Graph(nodes, edges, spec)``
    used elsewhere in this file — confirm it is still exercised.

    Returns:
        ``(points, array)``: the shifted graph and the cropped array, both with
        ROI ``output_roi``.
    """
    # Shift and crop the array
    center = array.spec.roi.get_offset() + array.spec.roi.get_shape() // 2
    new_center = center + direction
    new_offset = new_center - output_roi.get_shape() // 2
    new_roi = Roi(new_offset, output_roi.get_shape())
    array = array.crop(new_roi)
    array.spec.roi = output_roi

    # Work on a copy: assigning through ``points.spec`` directly would change
    # the ROI of the caller's graph as a side effect.
    new_points_spec = points.spec.copy()
    new_points_spec.roi = new_roi
    new_points_graph = nx.DiGraph()

    # shift points and add them to a graph
    for point_id, point in points.data.items():
        if new_roi.contains(point.location):
            new_point = point.copy()
            new_point.location = (point.location - new_offset + output_roi.get_begin())
            new_points_graph.add_node(
                new_point.point_id,
                point_id=new_point.point_id,
                parent_id=new_point.parent_id,
                location=new_point.location,
                label_id=new_point.label_id,
                radius=new_point.radius,
                point_type=new_point.point_type,
            )
            # Keep the edge only when the parent exists and is also in the window.
            if points.data.get(new_point.parent_id, False) and new_roi.contains(
                    points.data[new_point.parent_id].location):
                new_points_graph.add_edge(new_point.parent_id, new_point.point_id)

    # relabel connected components
    for i, connected_component in enumerate(
            nx.weakly_connected_components(new_points_graph)):
        for node in connected_component:
            new_points_graph.nodes[node]["label_id"] = i

    # store new graph data in points
    new_points_data = {
        point_id: Node(
            point["location"],
            point_id=point["point_id"],
            point_type=point["point_type"],
            radius=point["radius"],
            parent_id=point["parent_id"],
            label_id=point["label_id"],
        )
        for point_id, point in new_points_graph.nodes.items()
    }
    points = Graph(new_points_data, new_points_spec)
    points.spec.roi = output_roi
    return points, array
def provide(self, request):
    """Return both fixture components merged into one graph over the requested ROI."""
    graph_spec = self.graph_spec.copy()
    graph_spec.roi = request[self.graph].roi

    all_nodes = self.component_1_nodes + self.component_2_nodes
    all_edges = self.component_1_edges + self.component_2_edges

    batch = Batch()
    batch[self.graph] = Graph(all_nodes, all_edges, graph_spec)
    return batch
def test_transpose():
    """SimpleAugment transposition must move graph nodes and array voxels together."""
    voxel_size = Coordinate((20, 20))
    graph_key = GraphKey("GRAPH")
    array_key = ArrayKey("ARRAY")

    graph = Graph(
        [Node(id=1, location=np.array([450, 550]))],
        [],
        GraphSpec(roi=Roi((100, 200), (800, 600))),
    )
    marker = np.zeros([40, 30])
    marker[17, 17] = 1
    array = Array(
        marker, ArraySpec(roi=Roi((100, 200), (800, 600)), voxel_size=voxel_size))

    def make_pipeline(probs):
        # Identical sources; only the transpose probabilities differ.
        return (
            (GraphSource(graph_key, graph), ArraySource(array_key, array))
            + MergeProvider()
            + SimpleAugment(
                mirror_only=[], transpose_only=[0, 1], transpose_probs=probs))

    default_pipeline = make_pipeline([0, 0])
    transpose_pipeline = make_pipeline([1, 1])

    request = BatchRequest()
    request[graph_key] = GraphSpec(roi=Roi((400, 500), (200, 300)))
    request[array_key] = ArraySpec(roi=Roi((400, 500), (200, 300)))

    cases = (
        (default_pipeline, [450, 550]),
        (transpose_pipeline, [410, 590]),
    )
    for pipeline, expected_location in cases:
        with build(pipeline):
            batch = pipeline.request_batch(request)
            nodes = list(batch[graph_key].nodes)
            assert len(nodes) == 1
            node = nodes[0]
            assert all(np.isclose(node.location, expected_location))
            node_voxel_index = Coordinate(
                (node.location - batch[array_key].spec.roi.get_offset()) / voxel_size)
            assert (
                batch[array_key].data[node_voxel_index] == 1
            ), f"Node at {np.where(batch[array_key].data == 1)} not {node_voxel_index}"
def test_shift_points1(self):
    """With this shift pattern the single point must not survive in the result."""
    only_point = Node(id=1, location=np.array([0, 1]))
    source = Graph([only_point], [], GraphSpec(Roi(offset=(0, 0), shape=(5, 5))))
    request_roi = Roi(offset=(0, 1), shape=(5, 3))
    shifts = np.array([[0, -1], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int)
    expected = Graph([], [], GraphSpec(request_roi))

    result = ShiftAugment.shift_points(
        source,
        request_roi,
        shifts,
        shift_axis=0,
        lcm_voxel_size=Coordinate((1, 1)),
    )

    self.assertTrue(self.points_equal(result.nodes, expected.nodes))
    self.assertTrue(result.spec == GraphSpec(request_roi))
def provide(self, request):
    """Provide position-encoding RAW data, GT label regions and synapse graphs.

    * RAW: each voxel holds the sum of its (z, y, x) voxel coordinates.
    * GT_LABELS: ones, with parts of the lower half (axis 0) overwritten by
      labels 2 and 3.
    * PRESYN / POSTSYN: nodes from ``self.__get_pre_and_postsyn_locations``.
      Both are computed from the PRESYN ROI if requested, else the POSTSYN ROI.
    """
    batch = Batch()

    # have the pixels encode their position
    if ArrayKeys.RAW in request:
        # the z,y,x coordinates of the ROI
        roi = request[ArrayKeys.RAW].roi
        roi_voxel = roi // self.spec[ArrayKeys.RAW].voxel_size
        meshgrids = np.meshgrid(
            range(roi_voxel.get_begin()[0], roi_voxel.get_end()[0]),
            range(roi_voxel.get_begin()[1], roi_voxel.get_end()[1]),
            range(roi_voxel.get_begin()[2], roi_voxel.get_end()[2]),
            indexing="ij",
        )
        data = meshgrids[0] + meshgrids[1] + meshgrids[2]
        spec = self.spec[ArrayKeys.RAW].copy()
        spec.roi = roi
        batch.arrays[ArrayKeys.RAW] = Array(data, spec)

    if ArrayKeys.GT_LABELS in request:
        roi = request[ArrayKeys.GT_LABELS].roi
        roi_voxel_shape = (
            roi // self.spec[ArrayKeys.GT_LABELS].voxel_size
        ).get_shape()
        data = np.ones(roi_voxel_shape)
        data[roi_voxel_shape[0] // 2 :, roi_voxel_shape[1] // 2 :, :] = 2
        data[roi_voxel_shape[0] // 2 :, -(roi_voxel_shape[1] // 2) :, :] = 3
        spec = self.spec[ArrayKeys.GT_LABELS].copy()
        spec.roi = roi
        batch.arrays[ArrayKeys.GT_LABELS] = Array(data, spec)

    if GraphKeys.PRESYN in request:
        data_presyn, data_postsyn = self.__get_pre_and_postsyn_locations(
            roi=request[GraphKeys.PRESYN].roi
        )
    elif GraphKeys.POSTSYN in request:
        data_presyn, data_postsyn = self.__get_pre_and_postsyn_locations(
            roi=request[GraphKeys.POSTSYN].roi
        )

    for (graph_key, spec) in request.graph_specs.items():
        if graph_key == GraphKeys.PRESYN:
            data = data_presyn
        elif graph_key == GraphKeys.POSTSYN:
            data = data_postsyn
        # NOTE(review): any other graph key would reuse whatever ``data`` was
        # last bound to above — confirm only PRESYN/POSTSYN are ever requested.
        batch.graphs[graph_key] = Graph(
            list(data.values()), [], GraphSpec(spec.roi)
        )

    return batch
def process(self, batch, request):
    """Drop every node whose ``self.size_attr`` is below ``self.size_threshold``."""
    source = batch[self.graph]
    nx_graph = source.to_nx_graph()
    too_small = [
        node
        for node, node_attrs in nx_graph.nodes(data=True)
        if node_attrs[self.size_attr] < self.size_threshold
    ]
    nx_graph.remove_nodes_from(too_small)

    result = Batch()
    result[self.graph] = Graph.from_nx_graph(nx_graph, source.spec.copy())
    return result
def provide(self, request):
    """Return empty graphs and zero arrays for every requested key.

    Asserts that ``GraphKeys.TEST_GRAPH`` is requested exactly on calls where
    the call counter ``self.n`` is a multiple of ``self.every``.  The counter is
    incremented after each call.
    """
    outputs = Batch()
    if self.n % self.every == 0:
        assert GraphKeys.TEST_GRAPH in request
    else:
        assert GraphKeys.TEST_GRAPH not in request

    for key, requested_spec in request.items():
        # Copy so the caller's request specs are never mutated.
        spec = requested_spec.copy()
        if isinstance(key, GraphKey):
            outputs[key] = Graph([], [], spec)
        if isinstance(key, ArrayKey):
            spec.voxel_size = self.spec[key].voxel_size
            # One array element per voxel, not per world unit.
            voxel_shape = (spec.roi // spec.voxel_size).get_shape()
            outputs[key] = Array(np.zeros(voxel_shape, dtype=spec.dtype), spec)

    self.n += 1
    return outputs
def test_neighbors(self):
    """A directed and an undirected graph over the same edges report the same neighbours."""
    # Take two independent copies.  Binding ``self.spec`` to both names makes
    # them one object, so setting ``directed = False`` would make *both* graphs
    # undirected and would also mutate ``self.spec`` for later use.
    d_spec = self.spec.copy()
    ud_spec = self.spec.copy()
    ud_spec.directed = False

    directed = Graph(self.nodes, self.edges, d_spec)
    undirected = Graph(self.nodes, self.edges, ud_spec)

    self.assertCountEqual(
        directed.neighbors(self.nodes[0]), undirected.neighbors(self.nodes[0]))
def provide(self, request):
    """Serve position-encoding arrays and a sparse lattice of graph nodes.

    Each array voxel holds the sum of its (z, y, x) voxel coordinates.  Each
    graph gets a node at every location inside its ROI whose coordinates are
    multiples of (100, 10, 10), with ids from ``coordinate_to_id``.
    """
    batch = Batch()

    for array_key, requested in request.array_specs.items():
        roi = requested.roi
        voxel_roi = roi // self.spec[array_key].voxel_size
        begin, end = voxel_roi.get_begin(), voxel_roi.get_end()
        zz, yy, xx = np.meshgrid(
            *(range(begin[axis], end[axis]) for axis in range(3)),
            indexing='ij')
        array_spec = self.spec[array_key].copy()
        array_spec.roi = roi
        batch.arrays[array_key] = Array(zz + yy + xx, array_spec)

    steps = [100, 10, 10]
    for graph_key, requested in request.graph_specs.items():
        roi = requested.roi
        # Round the ROI begin down to the lattice so the first node is aligned.
        first = roi.get_begin() - tuple(
            b % s for b, s in zip(roi.get_begin(), steps))
        axes = [range(lo, hi, s) for lo, hi, s in zip(first, roi.get_end(), steps)]
        lattice = []
        for i, j, k in itertools.product(*axes):
            location = np.array([i, j, k])
            if roi.contains(location):
                lattice.append(Node(id=coordinate_to_id(i, j, k), location=location))
        batch.graphs[graph_key] = Graph(lattice, [], requested)

    return batch
def process(self, batch, request):
    """Iteratively prune short side branches off every connected component.

    For each component, every branch point (degree > 2) is tentatively
    removed.  Each piece the component falls apart into is measured with
    ``self.cable_len`` (branch point included).  Pieces measuring at most
    ``self.node_threshold`` are deleted from the graph.  The component is
    re-scanned until a full pass deletes nothing.
    """
    outputs = Batch()
    g = batch[self.graph].to_nx_graph()
    logger.debug(f"g has {len(g.nodes())} nodes pre filtering")
    cc_func = (nx.weakly_connected_components
               if g.is_directed() else nx.connected_components)
    ccs = cc_func(g)
    # Materialize the components before ``g`` is modified below.
    for cc in list(ccs):
        finished = False
        while not finished:
            finished = True
            # ``subgraph`` returns a live view: nodes removed from ``g`` below
            # disappear from it too.
            g_component = g.subgraph(cc)
            branch_points = [
                n for n in g_component.nodes if g_component.degree(n) > 2
            ]
            logger.debug(
                f"Connected component has {len(g_component.nodes)} nodes and {len(branch_points)} branch points"
            )
            removed = 0
            for i, branch_point in enumerate(branch_points):
                # NOTE(review): an earlier iteration of this loop may already
                # have removed ``branch_point`` from ``g``; confirm that
                # ``cable_len`` tolerates a missing node.
                remaining = [n for n in cc if n != branch_point]
                remaining_g = g_component.subgraph(remaining)
                remaining_ccs = list(cc_func(remaining_g))
                logger.debug(
                    f"After removing branch point {i}, cc is broken into pieces sized: {[len(x) for x in remaining_ccs]}"
                )
                for remaining_cc in list(remaining_ccs):
                    if (self.cable_len(g, list(remaining_cc) + [branch_point])
                            <= self.node_threshold):
                        for n in remaining_cc:
                            g.remove_node(n)
                        finished = False
                        # Count nodes (not pieces) so the log below is accurate.
                        removed += len(remaining_cc)
            logger.debug(f"Removed {removed} nodes from this cc")
    logger.debug(f"g has {len(g.nodes())} nodes post filtering")
    outputs[self.graph] = Graph.from_nx_graph(
        g, batch[self.graph].spec.copy())
    return outputs
class ExampleSourceRandomLocation(BatchProvider):
    """Expose a fixed three-node graph under ``GraphKeys.TEST_POINTS``, cropped and trimmed."""

    def __init__(self):
        coords = [[1, 1, 1], [500, 500, 500], [550, 550, 550]]
        self.graph = Graph(
            [Node(idx, np.array(xyz)) for idx, xyz in enumerate(coords, start=1)],
            [],
            GraphSpec(roi=Roi((-500, -500, -500), (1500, 1500, 1500))),
        )

    def setup(self):
        self.provides(GraphKeys.TEST_POINTS, self.graph.spec)

    def provide(self, request):
        requested_roi = request[GraphKeys.TEST_POINTS].roi
        result = Batch()
        result[GraphKeys.TEST_POINTS] = self.graph.crop(requested_roi).trim(requested_roi)
        return result
def process(self, batch, request):
    """Replace block-local component ids with global ones for nodes in the write ROI.

    The write ROI is the graph ROI shrunk by ``self.context``.  For each node
    inside it, the node's current ``self.component_attr`` is looked up as a
    node id in the graph returned by ``self.client.get_graph``.  The node's
    ``component_attr`` and ``size_attr`` are then overwritten with that entry's
    values.
    """
    g = batch[self.graph].to_nx_graph()
    logger.debug(
        f"{self.name()} got graph with {g.number_of_nodes()} nodes, and "
        f"{g.number_of_edges()} edges!")
    write_roi = batch[self.graph].spec.roi.grow(-self.context, -self.context)

    contained_nodes = [
        node for node, attr in g.nodes.items()
        if write_roi.contains(attr["location"])
    ]
    contained_components = set(g.nodes[n][self.component_attr]
                               for n in contained_nodes)
    logger.debug(f"Graph contains {len(contained_nodes)} nodes with "
                 f"{len(contained_components)} components in write_roi")

    component_graph = self.client.get_graph(roi=write_roi,
                                            node_inclusion="dangling",
                                            edge_inclusion="either")

    for node in contained_nodes:
        attrs = g.nodes[node]
        block_component_id = attrs[self.component_attr]
        # ``component_graph`` is indexed by the block-local component id.
        component_attrs = component_graph.nodes[block_component_id]
        attrs[self.component_attr] = component_attrs[self.component_attr]
        attrs[self.size_attr] = component_attrs[self.size_attr]

    # Recount after relabeling.  Previously this second log repeated the
    # pre-relabel numbers verbatim; it now reports the global component count.
    contained_components = set(g.nodes[n][self.component_attr]
                               for n in contained_nodes)
    logger.debug(f"Graph contains {len(contained_nodes)} nodes with "
                 f"{len(contained_components)} components in write_roi")

    outputs = Batch()
    outputs[self.graph] = Graph.from_nx_graph(
        g, batch[self.graph].spec.copy())
    return outputs
def process(self, batch, request):
    """Annotate each MST edge with a statistic over the embeddings along its chain.

    ``self.get_edge_chains`` yields, for every MST edge ``(u, v)``, a chain of
    dense-MST nodes.  The embedding values at each chain node's voxel are
    gathered and reduced by ``self.get_stat``.  The result is stored as
    ``self.distance_attr`` on the MST edge.
    """
    mst = batch[self.mst].to_nx_graph()
    dense_mst = batch[self.dense_mst].to_nx_graph()
    embeddings = batch[self.embeddings].data
    voxel_size = batch[self.embeddings].spec.voxel_size
    offset = batch[self.embeddings].spec.roi.get_begin()
    # All leading axes are taken in full.  Assumes the last three axes of
    # ``embeddings`` are the spatial ones indexed by node location — TODO confirm.
    leading_slices = (slice(None), ) * (len(embeddings.shape) - 3)

    for (u, v), chain in self.get_edge_chains(mst, dense_mst):
        chain_embeddings = []
        for n in chain:
            n_loc = dense_mst.nodes[n]["location"]
            n_ind = tuple(int(x) for x in ((n_loc - offset) // voxel_size))
            chain_embeddings.append(embeddings[leading_slices + n_ind])
        # Reduce the gathered embeddings.  Previously the raw node ids
        # (``chain``) were passed, leaving ``chain_embeddings`` unused.
        mst.edges[(u, v)][self.distance_attr] = self.get_stat(chain_embeddings)

    outputs = Batch()
    # Copy the spec so the output graph does not share it with the input batch.
    outputs[self.mst] = Graph.from_nx_graph(mst, batch[self.mst].spec.copy())
    return outputs
class TestSource(BatchProvider):
    """Single-node graph source whose ``prepare`` passes the request through unchanged."""

    def __init__(self):
        lone_node = Node(id=1, location=np.array([50, 70, 100]))
        bounds = Roi((-200, -200, -200), (400, 400, 478))
        self.graph = Graph([lone_node], [], GraphSpec(roi=bounds))

    def setup(self):
        self.provides(GraphKeys.TEST_GRAPH, self.graph.spec)

    def prepare(self, request):
        # Forward the request as-is to upstream providers.
        return request

    def provide(self, request):
        requested_roi = request[GraphKeys.TEST_GRAPH].roi
        result = Batch()
        result[GraphKeys.TEST_GRAPH] = self.graph.crop(requested_roi).trim(requested_roi)
        return result