def process(self, batch, request):
    """Merge the per-block ground-truth and MST graphs into two global graphs.

    For every entry of ``self.specs`` the graphs stored in ``batch`` under the
    block's ``"ground_truth"`` and ``"mst_pred"`` keys are converted to
    undirected networkx graphs and accumulated with ``nx.disjoint_union``.
    Each node of the merged graphs then gets an ``"id"`` attribute equal to
    its node label in the merged graph.

    Returns:
        A ``gp.Batch`` holding the merged graphs under ``self.gt`` and
        ``self.mst``, both undirected with an all-``None`` 3D ROI.
    """
    merged_gt = nx.Graph()
    merged_mst = nx.Graph()
    for block_specs in self.specs.values():
        # Each spec entry is indexable; element 0 is the batch key.
        gt_key = block_specs["ground_truth"][0]
        mst_key = block_specs["mst_pred"][0]
        merged_gt = nx.disjoint_union(
            merged_gt, batch[gt_key].to_nx_graph().to_undirected())
        merged_mst = nx.disjoint_union(
            merged_mst, batch[mst_key].to_nx_graph().to_undirected())

    result = gp.Batch()
    for out_key, merged in ((self.gt, merged_gt), (self.mst, merged_mst)):
        # Keep the "id" attribute in sync with the merged node labels.
        for node, attrs in merged.nodes.items():
            attrs["id"] = node
        unbounded_spec = gp.GraphSpec(
            roi=gp.Roi((None,) * 3, (None,) * 3), directed=False)
        result[out_key] = gp.Graph.from_nx_graph(merged, unbounded_spec)
    return result
def prepare(self, request):
    """Build the upstream request for the graphs this node consumes.

    The MST and ground-truth graphs are always requested; the connectivity
    graph is requested only when ``self.connectivity`` is set. Every graph is
    requested over ``self.roi``.

    Returns:
        A ``gp.BatchRequest`` with one ``gp.GraphSpec`` per required key.
    """
    required_keys = [self.mst, self.gt]
    if self.connectivity is not None:
        required_keys.append(self.connectivity)

    upstream = gp.BatchRequest()
    for key in required_keys:
        upstream[key] = gp.GraphSpec(roi=self.roi)
    return upstream
def setup(self):
    """Announce ``self.graph_key`` with an ROI of unbounded shape.

    The spec is also stored on ``self.graph_spec``.
    """
    ndim = self.dims
    # Points are provided in an infinite ROI: zero offset, ``None`` shape.
    infinite_roi = gp.Roi(offset=(0,) * ndim, shape=(None,) * ndim)
    self.graph_spec = gp.GraphSpec(roi=infinite_roi)
    self.provides(self.graph_key, self.graph_spec)
def setup(self):
    """Declare the points key, and optionally the labels key, of this source.

    ``self.ndims`` is taken from the second axis of ``self.data``. The points
    key is provided with:

    * ``self.points_spec`` when one was given, otherwise
    * an ``ArraySpec`` with voxel size ``(1,)`` for an ``ArrayKey``, or
    * a ``GraphSpec`` whose ROI is the bounding box of the data (floor of the
      minimum to ceil of the maximum plus one) for a ``GraphKey``.

    When ``self.labels`` is set it must be an ``ArrayKey``; it is provided
    with ``self.labels_spec`` or, failing that, a voxel-size ``(1,)`` spec.

    Raises:
        AssertionError: if ``self.labels`` is set but is not an ``ArrayKey``.
    """
    self.ndims = self.data.shape[1]

    if self.points_spec is not None:
        self.provides(self.points, self.points_spec)
    elif isinstance(self.points, gp.ArrayKey):
        self.provides(self.points, gp.ArraySpec(voxel_size=(1,)))
    elif isinstance(self.points, gp.GraphKey):
        # Diagnostic output goes through the logger instead of stdout.
        logger.debug("Point dimensionality: %s", self.ndims)
        coordinates = self.data[:, :self.ndims]
        min_bb = gp.Coordinate(np.floor(np.amin(coordinates, 0)))
        # +1 so the upper bound (exclusive) still contains the maximum point.
        max_bb = gp.Coordinate(np.ceil(np.amax(coordinates, 0)) + 1)
        roi = gp.Roi(min_bb, max_bb - min_bb)
        logger.debug(f"Bounding Box: {roi}")
        self.provides(self.points, gp.GraphSpec(roi=roi))

    if self.labels is not None:
        # Parenthesised message: a backslash continuation inside the string
        # would embed the source indentation in the error text.
        assert isinstance(self.labels, gp.ArrayKey), (
            f"Label key must be an ArrayKey, was given {type(self.labels)}")
        if self.labels_spec is not None:
            self.provides(self.labels, self.labels_spec)
        else:
            self.provides(self.labels, gp.ArraySpec(voxel_size=(1,)))
def setup(self):
    """Read the snapshot container and declare a spec for every dataset key.

    ``.h5``/``.hdf`` files are opened with h5py, ``.zarr`` with zarr, both
    read-only. For each ``ArrayKey`` in ``self.datasets`` the spec comes from
    ``self.spec_from_dataset``; for each ``GraphKey`` an all-``None`` ROI is
    provided whose dimensionality is the length of the first stored location,
    with directedness looked up in ``self.directed``.

    Raises:
        ValueError: if the snapshot file has an unsupported extension.
        KeyError: if a configured dataset path is missing from the container.
    """
    snapshot_name = str(self.snapshot_file)
    is_hdf = snapshot_name.endswith((".h5", ".hdf"))
    if is_hdf:
        data = h5py.File(self.snapshot_file, "r")
    elif snapshot_name.endswith(".zarr"):
        data = zarr.open(self.snapshot_file, "r")
    else:
        # Previously this fell through and failed later on an unbound `data`.
        raise ValueError(
            f"Unsupported snapshot file {self.snapshot_file}: "
            "expected a .h5, .hdf or .zarr container")

    try:
        for key, path in self.datasets.items():
            if isinstance(key, gp.ArrayKey):
                try:
                    dataset = data[path]
                except KeyError as err:
                    raise KeyError(f"Could not find {path}") from err
                self.provides(key, self.spec_from_dataset(dataset))
            elif isinstance(key, gp.GraphKey):
                try:
                    locations = data[f"{path}-locations"]
                except KeyError as err:
                    raise KeyError(
                        f"Could not find {path}-locations") from err
                ndim = len(locations[0])
                graph_spec = gp.GraphSpec(
                    gp.Roi((None,) * ndim, (None,) * ndim),
                    directed=self.directed.get(key),
                )
                self.provides(key, graph_spec)
    finally:
        # Only metadata is needed here; release the HDF5 handle right away.
        if is_hdf:
            data.close()
def provide(self, request):
    """Return a batch with an edge-less, undirected graph of random points.

    The points come from ``self.random_point_generator.get_random_points``
    called with the requested ROI; each ``(id, location)`` item becomes one
    ``gp.Node``.
    """
    requested_roi = request[self.graph_key].roi
    sampled = self.random_point_generator.get_random_points(requested_roi)

    nodes = []
    for node_id, location in sampled.items():
        nodes.append(gp.Node(id=node_id, location=location))

    result = gp.Batch()
    result[self.graph_key] = gp.Graph(
        nodes, [], gp.GraphSpec(roi=requested_roi, directed=False))
    return result
def setup(self):
    """Restrict the per-block graph specs and announce the merged graphs.

    For every block in ``self.specs`` the upstream spec of the
    ``"ground_truth"`` and ``"mst_pred"`` keys is copied, its ROI replaced by
    the ROI configured for the block, and registered with ``self.updates``.
    ``self.total_roi`` becomes the union of all those ROIs. Finally
    ``self.mst`` and ``self.gt`` are provided as undirected graphs with an
    all-``None`` 3D ROI.
    """
    self.enable_autoskip()

    block_rois = []
    for block_specs in self.specs.values():
        # Each entry is a (key, spec) pair.
        for key, spec in (block_specs["ground_truth"],
                          block_specs["mst_pred"]):
            restricted = self.spec[key].copy()
            restricted.roi = spec.roi
            self.updates(key, restricted)
            block_rois.append(restricted.roi)

    first_roi, *remaining_rois = block_rois[:1] + block_rois[1:]
    self.total_roi = block_rois[0]
    for roi in remaining_rois:
        self.total_roi = self.total_roi.union(roi)

    for merged_key in (self.mst, self.gt):
        self.provides(
            merged_key,
            gp.GraphSpec(roi=gp.Roi((None,) * 3, (None,) * 3),
                         directed=False))
def provide(self, request):
    """Provide the points (and optional labels) for a request.

    * If ``self.points`` is an ``ArrayKey``, ``self.num_points`` row indices
      are drawn with ``np.random.choice`` (with replacement, the NumPy
      default), the rows get a new leading axis and are multiplied by
      ``self.scale`` when it is set.
    * Otherwise the rows whose first ``self.ndims`` columns fall inside the
      requested ROI (begin inclusive, end exclusive) are passed as a boolean
      mask to ``self._get_points`` and returned as an edge-less graph.

    When ``self.label_data`` is set the matching labels are returned as an
    array under ``self.labels``. Timing information is added to the batch's
    profiling stats.
    """
    timing = Timing(self)
    timing.start()

    batch = gp.Batch()

    # If an Array is requested then we will randomly choose
    # the number of requested points
    if isinstance(self.points, gp.ArrayKey):
        points = np.random.choice(self.data.shape[0], self.num_points)
        data = self.data[points][np.newaxis]
        if self.scale is not None:
            data = data * self.scale
        if self.label_data is not None:
            labels = self.label_data[points]
        batch[self.points] = gp.Array(data, self.spec[self.points])
    else:
        # If a graph is requested we must select points within the
        # request ROI
        requested_roi = request[self.points].roi
        min_bb = requested_roi.get_begin()
        max_bb = requested_roi.get_end()

        logger.debug("Points source got request for %s", requested_roi)

        # `np.bool` was removed from NumPy (1.24); the builtin `bool` is the
        # documented replacement and yields the same dtype.
        point_filter = np.ones((self.data.shape[0], ), dtype=bool)
        for d in range(self.ndims):
            point_filter = np.logical_and(point_filter,
                                          self.data[:, d] >= min_bb[d])
            point_filter = np.logical_and(point_filter,
                                          self.data[:, d] < max_bb[d])

        points_data, labels = self._get_points(point_filter)
        logger.debug(f"Found {len(points_data)} points")
        points_spec = gp.GraphSpec(roi=request[self.points].roi.copy())
        batch.graphs[self.points] = gp.Graph(points_data, [], points_spec)

    # Labels will always be an Array
    if self.label_data is not None:
        batch[self.labels] = gp.Array(labels, self.spec[self.labels])

    timing.stop()
    batch.profiling_stats.add(timing)

    return batch
def graph_from_path(self, graph_key, data, path):
    """Rebuild a ``gp.Graph`` from the datasets stored under ``path``.

    Node ids, edges and locations are read from ``{path}-ids``,
    ``{path}-edges`` and ``{path}-locations``. Extra node/edge attributes
    listed for ``graph_key`` in ``self.node_attrs`` / ``self.edge_attrs`` are
    read from ``{path}/node_attrs/<attr>`` and ``{path}/edge_attrs/<attr>``.

    Returns:
        A ``gp.Graph`` with an all-``None`` ROI whose dimensionality is the
        length of the first stored location, and directedness taken from
        ``self.directed``.
    """
    ids = data[f"{path}-ids"]
    edge_pairs = data[f"{path}-edges"]
    locations = data[f"{path}-locations"]

    node_attr_names = []
    node_attr_columns = []
    for name in self.node_attrs.get(graph_key, []):
        node_attr_names.append(name)
        node_attr_columns.append(data[f"{path}/node_attrs/{name}"])
    # The trailing column of Nones keeps zip yielding one row per node even
    # when no attributes are configured; it is dropped again by the
    # name/value zip below because there is no matching name.
    node_rows = zip(*node_attr_columns, (None,) * len(locations))

    nodes = []
    for node_id, location, row in zip(ids, locations, node_rows):
        nodes.append(
            gp.Node(
                node_id,
                location=location,
                attrs=dict(zip(node_attr_names, row)),
            ))

    edge_attr_names = []
    edge_attr_columns = []
    for name in self.edge_attrs.get(graph_key, []):
        edge_attr_names.append(name)
        edge_attr_columns.append(data[f"{path}/edge_attrs/{name}"])
    # Same padding trick as for the nodes.
    edge_rows = zip(*edge_attr_columns, (None,) * len(edge_pairs))

    edges = []
    for (u, v), row in zip(edge_pairs, edge_rows):
        edges.append(gp.Edge(u, v, attrs=dict(zip(edge_attr_names, row))))

    unbounded = (None,) * len(locations[0])
    spec = gp.GraphSpec(
        gp.Roi(unbounded, (None,) * len(locations[0])),
        directed=self.directed.get(graph_key),
    )
    return gp.Graph(nodes, edges, spec)
def get_requests(config, blocks, raw, emb_pred, labels, gt):
    """Create one ``gp.BatchRequest`` per block.

    ``raw`` is requested over the block's cube ROI grown on both sides by half
    the difference between the world-space input and output sizes;
    ``emb_pred``, ``labels`` and ``gt`` are requested over the cube ROI
    itself.

    Returns:
        A list of requests in the order of ``blocks``.
    """
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_size = voxel_size * gp.Coordinate(config["INPUT_SHAPE"])
    output_size = voxel_size * gp.Coordinate(config["OUTPUT_SHAPE"])
    context = (input_size - output_size) // 2

    cube_rois = []
    for block in blocks:
        cube_rois.append(get_cube_roi(config, block))

    def _request_for(cube_roi):
        # raw needs the surrounding context, everything else only the cube.
        block_request = gp.BatchRequest()
        block_request[raw] = gp.ArraySpec(roi=cube_roi.grow(context, context))
        block_request[emb_pred] = gp.ArraySpec(roi=cube_roi)
        block_request[labels] = gp.ArraySpec(roi=cube_roi)
        block_request[gt] = gp.GraphSpec(roi=cube_roi)
        return block_request

    return [_request_for(cube_roi) for cube_roi in cube_rois]
def validation_data_sources_from_snapshots(config, blocks):
    """Pair snapshot-backed sources with a request for every block.

    For each block two ``SnapshotSource`` nodes are created on
    ``block_<block>.hdf`` inside ``config["VALIDATION_BLOCKS"]``: one for the
    labels volume plus the (directed) ground-truth graph, one for the raw
    volume. The request asks for raw over the cube ROI grown by the
    input/output context and for labels and ground truth over the cube ROI.

    Returns:
        ``(block_pipelines, (raw, labels, ground_truth))`` where
        ``block_pipelines`` is a list of ``(sources, request)`` tuples.
    """
    snapshot_dir = Path(config["VALIDATION_BLOCKS"])

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_size = voxel_size * gp.Coordinate(config["INPUT_SHAPE"])
    output_size = voxel_size * gp.Coordinate(config["OUTPUT_SHAPE"])
    context = (input_size - output_size) // 2

    block_pipelines = []
    for block in blocks:
        labels_and_gt_source = SnapshotSource(
            snapshot_dir / f"block_{block}.hdf",
            {labels: "volumes/labels", ground_truth: "points/gt"},
            directed={ground_truth: True},
        )
        raw_source = SnapshotSource(
            snapshot_dir / f"block_{block}.hdf", {raw: "volumes/raw"})

        cube_roi = get_cube_roi(config, block)

        block_request = gp.BatchRequest()
        block_request[raw] = gp.ArraySpec(cube_roi.grow(context, context))
        block_request[ground_truth] = gp.GraphSpec(cube_roi)
        block_request[labels] = gp.ArraySpec(cube_roi)

        block_pipelines.append(
            ((labels_and_gt_source, raw_source), block_request))

    return block_pipelines, (raw, labels, ground_truth)
def validation_data_sources_recomputed(config, blocks):
    """Build per-block validation pipelines that recompute labels from SWCs.

    Validation directories are discovered below
    ``config["BENCHMARK_DATA_PATH"]``: every sub-directory of a directory
    whose name contains ``"validation"`` is keyed by the integer after the
    last underscore of its name, and the ones listed in ``blocks`` are used
    in that order. For each directory the ``cube*.swc`` file defines the block
    ROI, and a pipeline is assembled that reads raw data and the ground-truth
    skeletons, normalizes raw, rasterizes the skeletons into labels and grows
    them.

    Returns:
        ``(validation_pipelines, (raw, labels, ground_truth))`` where
        ``validation_pipelines`` is a list of ``(pipeline, request)`` tuples.
    """
    benchmark_root = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    samples_root = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"
    neuron_width = int(config["NEURON_RADIUS"])

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_size = voxel_size * gp.Coordinate(config["INPUT_SHAPE"])
    output_size = voxel_size * gp.Coordinate(config["OUTPUT_SHAPE"])
    context = (input_size - output_size) // 2

    # Map block number -> directory for every requested block found on disk.
    dirs_by_block = {}
    for group in benchmark_root.iterdir():
        if "validation" not in group.name or not group.is_dir():
            continue
        for candidate in group.iterdir():
            block_number = int(candidate.name.split("_")[-1])
            if block_number in blocks:
                dirs_by_block[block_number] = candidate
    ordered_dirs = [dirs_by_block[block] for block in blocks]

    raw = gp.ArrayKey("RAW")
    ground_truth = gp.GraphKey("GROUND_TRUTH")
    labels = gp.ArrayKey("LABELS")

    transform_file = transform_template.format(sample=sample)

    validation_pipelines = []
    for validation_dir in ordered_dirs:
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            file_name = gt_file.name
            if not file_name.endswith(".swc"):
                continue
            if file_name.startswith("tree"):
                trees.append(gt_file)
            elif file_name.startswith("cube"):
                cube = gt_file
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_file),
            np.array([300, 300, 1000]),
        )

        raw_source = gp.ZarrSource(
            filename=str(Path(samples_root, sample, raw_n5).absolute()),
            datasets={raw: "volume-rechunked"},
            array_specs={
                raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size)
            },
        )
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_file,
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        # Compose left to right, one node at a time, exactly as a single
        # chained `+` expression would.
        pipeline = (raw_source, swc_source) + gp.nodes.MergeProvider()
        pipeline = pipeline + gp.Normalize(raw, dtype=np.float32)
        pipeline = pipeline + nl.gunpowder.RasterizeSkeleton(
            ground_truth,
            labels,
            connected_component_labeling=True,
            array_spec=gp.ArraySpec(
                voxel_size=voxel_size,
                dtype=np.int64,
                roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ),
            ),
        )
        pipeline = pipeline + nl.gunpowder.GrowLabels(
            labels, radii=[neuron_width * 1000])

        block_request = gp.BatchRequest()
        input_roi = cube_roi.grow(context, context)
        print(f"input_roi has shape: {input_roi.get_shape()}")
        print(f"cube_roi has shape: {cube_roi.get_shape()}")
        block_request[raw] = gp.ArraySpec(input_roi)
        block_request[ground_truth] = gp.GraphSpec(cube_roi)
        block_request[labels] = gp.ArraySpec(cube_roi)

        validation_pipelines.append((pipeline, block_request))

    return validation_pipelines, (raw, labels, ground_truth)
def validation_pipeline(config):
    """Build a pipeline that snapshots raw, CLAHE'd raw, labels and GT per block.

    For every block in ``config["BLOCKS"]`` a sub-pipeline is assembled that
    merges an SWC ground-truth source with a Zarr raw source (read twice, once
    plain and once passed through ``scipyCLAHE``), crops everything to ROIs
    anchored at the origin, rasterizes and grows the skeleton labels, and
    writes a snapshot into ``validations/validations.hdf``. All block
    pipelines are then combined through a ``MergeProvider``.

    Returns:
        ``(validation_pipeline, specs)`` where ``specs`` maps each block to a
        dict of ``{key: spec}`` for the keys created for that block.
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    # Largest voxel extent; used below to scale the label growth radius.
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    # NOTE(review): read from config but not used anywhere below.
    distance_attr = config["DISTANCE_ATTR"]

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        # NOTE(review): `trees` is collected but never read in this function.
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        # If no cube*.swc file was found, `cube` is still None here and this
        # line raises AttributeError rather than AssertionError.
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        # Keys are suffixed with the block id so blocks can be merged later.
        raw = gp.ArrayKey(f"RAW_{block}")
        raw_clahed = gp.ArrayKey(f"RAW_CLAHED_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")

        # Both raw keys read the same dataset; only raw_clahed goes
        # through scipyCLAHE.
        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={
                raw: "volume-rechunked",
                raw_clahed: "volume-rechunked"
            },
            array_specs={
                raw:
                gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
                raw_clahed:
                gp.ArraySpec(interpolatable=True, voxel_size=voxel_size),
            },
        ) + gp.Normalize(raw, dtype=np.float32) +
                      gp.Normalize(raw_clahed, dtype=np.float32) +
                      scipyCLAHE([raw_clahed], [20, 64, 64]))
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()
        # NOTE(review): this value is overwritten two statements below
        # before it is ever read.
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)

        # Same shape as the cube ROI but with its offset moved to the origin.
        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow((input_size - output_size) // 2,
                                          (input_size - output_size) // 2)

        # Record each key's spec for the caller and mirror it in the
        # snapshot's additional request.
        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec[raw_clahed] = gp.ArraySpec(input_roi)
        additional_request[raw_clahed] = gp.ArraySpec(roi=input_roi)
        block_spec[ground_truth] = gp.GraphSpec(cube_roi_shifted)
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi_shifted)
        block_spec[labels] = gp.ArraySpec(cube_roi_shifted)
        additional_request[labels] = gp.ArraySpec(roi=cube_roi_shifted)

        pipeline = ((swc_source, raw_source) + gp.nodes.MergeProvider() +
                    gp.SpecifiedLocation(locations=[cube_roi.get_center()]) +
                    gp.Crop(raw, roi=input_roi) +
                    gp.Crop(raw_clahed, roi=input_roi) +
                    gp.Crop(ground_truth, roi=cube_roi_shifted) +
                    nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    ) + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale]) +
                    gp.Crop(labels, roi=cube_roi_shifted) + gp.Snapshot(
                        {
                            raw: f"volumes/{block}/raw",
                            raw_clahed: f"volumes/{block}/raw_clahe",
                            ground_truth: f"points/{block}/ground_truth",
                            labels: f"volumes/{block}/labels",
                        },
                        additional_request=additional_request,
                        output_dir="validations",
                        output_filename="validations.hdf",
                    ))

        validation_pipelines.append(pipeline)

    # Combine the per-block pipelines into a single provider.
    validation_pipeline = (tuple(pipeline
                                 for pipeline in validation_pipelines) +
                           gp.MergeProvider() + gp.PrintProfilingStats())
    return validation_pipeline, specs
def process(self, batch, request):
    """Score the MST against ground truth over a sweep of edge thresholds.

    MST edges are sorted by ``self.edge_threshold_attr``. ``num_thresholds``
    evenly spaced thresholds are taken between the edge values found at the
    ``threshold_range`` fractions of the sorted list (or between 0 and 1 when
    there are no edges). For each threshold the graph made of all edges up to
    that threshold is built (optionally re-wired through the connectivity
    graph), spatially small components are removed, components are matched
    against the ground truth with ``match_components`` and a pseudo graph edit
    distance is computed.

    Returns:
        A ``gp.Batch`` containing

        * ``self.output``: non-spatial array with rows edit distance,
          threshold, node count, edge count, split cost, merge cost,
          false-positive cost, false-negative cost (one column per threshold),
        * ``self.output_graph`` (if set): the thresholded graph with the
          lowest edit distance,
        * ``self.details`` (if set): a graph with per-threshold matching
          details on nodes and edges.

    NOTE(review): the nesting below was reconstructed from a source whose
    indentation was lost; the spots where it matters are marked and should be
    confirmed against the repository.
    """
    num_thresholds = self.num_thresholds
    threshold_range = self.threshold_range

    outputs = gp.Batch()

    gt_graph = batch[self.gt].to_nx_graph().to_undirected()
    mst_graph = batch[self.mst].to_nx_graph().to_undirected()
    if self.connectivity is not None:
        connectivity_graph = batch[
            self.connectivity].to_nx_graph().to_undirected()

    # assert mst_graph.number_of_nodes() > 0, f"mst_graph is empty!"

    if self.details is not None:
        # The details graph holds the mst nodes under their own ids and the
        # gt nodes shifted by node_offset.
        matching_details_graph = nx.Graph()
        # NOTE(review): as nested here, node_offset is only assigned when
        # mst_graph is empty, yet it is read below for every gt node --
        # confirm the intended condition / indentation.
        if mst_graph.number_of_nodes() == 0:
            node_offset = max([node for node in mst_graph.nodes] + [-1]) + 1
            # label_offset is not read anywhere else in this method.
            label_offset = len(list(
                nx.connected_components(mst_graph))) + 1

        for node, attrs in mst_graph.nodes.items():
            matching_details_graph.add_node(node, **copy.deepcopy(attrs))
        for edge, attrs in mst_graph.edges.items():
            matching_details_graph.add_edge(edge[0], edge[1],
                                            **copy.deepcopy(attrs))
        for node, attrs in gt_graph.nodes.items():
            matching_details_graph.add_node(node + node_offset,
                                            **copy.deepcopy(attrs))
            matching_details_graph.nodes[node + node_offset]["id"] = (
                node + node_offset)
        for edge, attrs in gt_graph.edges.items():
            matching_details_graph.add_edge(edge[0] + node_offset,
                                            edge[1] + node_offset,
                                            **copy.deepcopy(attrs))

    # (edge, threshold value) pairs, sorted by ascending threshold value.
    edges = [(edge, attrs[self.edge_threshold_attr])
             for edge, attrs in mst_graph.edges.items()]
    edges = list(sorted(edges, key=lambda x: x[1]))
    edge_lens = [e[1] for e in edges]

    # min_threshold = edges[0][1]
    if len(edge_lens) > 0:
        # threshold_range holds fractions of the sorted edge list.
        min_threshold = edge_lens[int(len(edge_lens) * threshold_range[0])]
        max_threshold = edge_lens[int(len(edge_lens) * threshold_range[1]) -
                                  1]
    else:
        min_threshold = 0
        max_threshold = 1
    thresholds = np.linspace(min_threshold,
                             max_threshold,
                             num=num_thresholds)

    # Grown incrementally: thresholds ascend, so edges are only ever added.
    current_threshold_mst = nx.Graph()

    edge_deque = deque(edges)
    edit_distances = []
    split_costs = []
    merge_costs = []
    false_pos_costs = []
    false_neg_costs = []
    num_nodes = []
    num_edges = []

    best_score = None
    best_graph = None

    for threshold_index, threshold in enumerate(thresholds):
        logger.warning(f"Using threshold: {threshold}")
        # Move every edge at or below the current threshold into the graph.
        while len(edge_deque) > 0 and edge_deque[0][1] <= threshold:
            (u, v), _ = edge_deque.popleft()
            # `attrs` is not read before being rebound further down.
            attrs = mst_graph.edges[(u, v)]
            current_threshold_mst.add_edge(u, v)
            current_threshold_mst.add_node(u, **mst_graph.nodes[u])
            current_threshold_mst.add_node(v, **mst_graph.nodes[v])

        if self.connectivity is not None:
            # Rebuild the graph from connectivity-graph nodes/edges, keeping
            # only connectivity edges whose endpoints share an MST component.
            temp = nx.Graph()
            next_node = max([node for node in connectivity_graph.nodes]) + 1
            for i, cc in enumerate(
                    nx.connected_components(current_threshold_mst)):
                component_subgraph = current_threshold_mst.subgraph(cc)
                for node in component_subgraph.nodes:
                    temp.add_node(node,
                                  **dict(connectivity_graph.nodes[node]))
                    temp.nodes[node]["component"] = i
            for edge in connectivity_graph.edges:
                if (edge[0] in temp.nodes and edge[1] in temp.nodes
                        and temp.nodes[edge[0]]["component"]
                        == temp.nodes[edge[1]]["component"]):
                    temp.add_edge(
                        edge[0], edge[1],
                        **dict(connectivity_graph.edges[edge]))
                elif False:
                    # Disabled branch (condition is the literal False).
                    path = nx.shortest_path(connectivity_graph, edge[0],
                                            edge[1])
                    cloned_path = []
                    for node in path:
                        if node in temp.nodes:
                            cloned_path.append(node)
                        else:
                            cloned_path.append(next_node)
                            next_node += 1
                    path_len = len(cloned_path) - 1
                    for i, j in zip(range(path_len),
                                    range(1, path_len + 1)):
                        u = cloned_path[i]
                        if u not in temp.nodes:
                            temp.add_node(
                                u,
                                **dict(
                                    connectivity_graph.nodes[path[i]]))
                        v = cloned_path[j]
                        if v not in temp.nodes:
                            temp.add_node(
                                v,
                                **dict(
                                    connectivity_graph.nodes[path[j]]))
                        temp.add_edge(
                            u,
                            v,
                            **dict(connectivity_graph.edges[path[i],
                                                            path[j]]),
                        )
        else:
            temp = copy.deepcopy(current_threshold_mst)
            # NOTE(review): this labelling loop is placed inside the else
            # branch (the connectivity branch sets "component" itself) --
            # confirm against the original indentation.
            for i, cc in enumerate(nx.connected_components(temp)):
                for node in cc:
                    attrs = temp.nodes[node]
                    attrs["component"] = i

        # remove small connected_components
        # A component is "small" when the diagonal of its axis-aligned
        # bounding box is shorter than self.small_component_threshold.
        false_pos_nodes = []
        for cc in nx.connected_components(temp):
            cc_graph = temp.subgraph(cc)
            min_loc = None
            max_loc = None
            for node, attrs in cc_graph.nodes.items():
                node_loc = attrs[self.location_attr]
                if min_loc is None:
                    min_loc = node_loc
                else:
                    min_loc = np.min(np.array([node_loc, min_loc]), axis=0)
                if max_loc is None:
                    max_loc = node_loc
                else:
                    max_loc = np.max(np.array([node_loc, max_loc]), axis=0)
            if np.linalg.norm(min_loc - max_loc) < self.small_component_threshold:
                false_pos_nodes += list(cc)
        for node in false_pos_nodes:
            temp.remove_node(node)

        # x = thresholded prediction, y = ground truth.
        nodes_x = list(temp.nodes)
        nodes_y = list(gt_graph.nodes)

        node_labels_x = {
            node: attrs["component"]
            for node, attrs in temp.nodes.items()
        }

        node_labels_y = {
            node: component
            for component, nodes in enumerate(
                nx.connected_components(gt_graph)) for node in nodes
        }

        # Candidate pairs are computed gt -> prediction, then flipped.
        edges_yx = get_edges_xy(
            gt_graph,
            temp,
            location_attr=self.location_attr,
            node_match_threshold=self.comatch_threshold,
        )
        edges_xy = [(v, u) for u, v in edges_yx]

        result = match_components(
            nodes_x,
            nodes_y,
            edges_xy,
            copy.deepcopy(node_labels_x),
            copy.deepcopy(node_labels_y),
        )

        if self.details is not None:
            # add a match details graph to the batch
            # details is a graph containing nodes from both mst and gt
            # to access details of a node, use `details.nodes[node]["details"]`
            # where the details returned are a boolean numpy array of shape
            # (num_thresholds, 7).
            # the 7 values stored per threshold are, in order:
            # selected, success, fp, fn, merge, split, gt

            # NODES
            # success, mst: n matches to only nodes of 1 label, which matches its own label
            # success, gt: n matches to only nodes of 1 label, which matches its own label
            # fp: n in mst matches to nothing
            # fn: n in gt matches to nothing
            # merge: n in mst matches to a node with label not matching its own
            # split: n in gt matches to a node with label not matching its own
            # selected: n in mst in thresholded graph

            # EDGES
            # success, mst: both endpoints successful
            # success, gt: both endpoints successful
            # fp: both endpoints fp
            # fn: both endpoints fn
            # merge: e in mst: only one endpoint successful
            # split: e in gt: only one endpoint successful
            # selected: e in mst in thresholded graph

            (label_matches, node_matches, splits, merges, fps,
             fns) = result
            # create lookup tables:
            x_label_match_lut = {}
            y_label_match_lut = {}
            for a, b in label_matches:
                x_matches = x_label_match_lut.setdefault(a, set())
                x_matches.add(b)
                y_matches = y_label_match_lut.setdefault(b, set())
                y_matches.add(a)
            x_node_match_lut = {}
            y_node_match_lut = {}
            for a, b in node_matches:
                x_matches = x_node_match_lut.setdefault(a, set())
                x_matches.add(b)
                y_matches = y_node_match_lut.setdefault(b, set())
                y_matches.add(a)

            for node, attrs in matching_details_graph.nodes.items():
                # Ids at or above node_offset are gt nodes (see the offset
                # applied when the details graph was built).
                gt = int(node >= node_offset)
                # `mst` is not read below.
                mst = 1 - gt
                if gt == 1:
                    node = node - node_offset
                # gt nodes always count as selected; mst nodes only when
                # they are present in the thresholded graph.
                selected = gt or (node in temp.nodes())
                if selected:
                    success, fp, fn, merge, split, label_pair = self.node_matching_result(
                        node,
                        gt,
                        x_label_match_lut,
                        y_label_match_lut,
                        x_node_match_lut,
                        y_node_match_lut,
                        node_labels_x,
                        node_labels_y,
                    )
                else:
                    success, fp, fn, merge, split, label_pair = (
                        0,
                        0,
                        0,
                        0,
                        0,
                        (-1, -1),
                    )
                data = attrs.setdefault(
                    "details", np.zeros((len(thresholds), 7), dtype=bool))
                data[threshold_index] = [
                    selected,
                    success,
                    fp,
                    fn,
                    merge,
                    split,
                    gt,
                ]
                label_pairs = attrs.setdefault("label_pair", [])
                label_pairs.append(label_pair)
                # One label pair per threshold processed so far.
                assert len(label_pairs) == threshold_index + 1
            for (u, v), attrs in matching_details_graph.edges.items():
                (
                    u_selected,
                    u_success,
                    u_fp,
                    u_fn,
                    u_merge,
                    u_split,
                    u_gt,
                ) = matching_details_graph.nodes[u]["details"][
                    threshold_index]
                (
                    v_selected,
                    v_success,
                    v_fp,
                    v_fn,
                    v_merge,
                    v_split,
                    v_gt,
                ) = matching_details_graph.nodes[v]["details"][
                    threshold_index]

                # Edges never connect an mst node with a gt node.
                assert u_gt == v_gt
                e_gt = u_gt

                u_label_pair = matching_details_graph.nodes[u][
                    "label_pair"][threshold_index]
                v_label_pair = matching_details_graph.nodes[v][
                    "label_pair"][threshold_index]

                e_selected = u_selected and v_selected
                e_success = (e_selected and u_success and v_success
                             and (u_label_pair == v_label_pair))
                e_fp = u_fp and v_fp
                e_fn = u_fn and v_fn
                e_merge = e_selected and (not e_success) and (
                    not e_fp) and not e_gt
                e_split = e_selected and (not e_success) and (
                    not e_fn) and e_gt
                assert not (e_success and e_merge)
                assert not (e_success and e_split)

                data = attrs.setdefault(
                    "details", np.zeros((len(thresholds), 7), dtype=bool))
                if e_success:
                    label_pairs = attrs.setdefault("label_pair", [])
                    label_pairs.append(u_label_pair)
                    assert len(label_pairs) == threshold_index + 1
                else:
                    label_pairs = attrs.setdefault("label_pair", [])
                    label_pairs.append((-1, -1))
                    assert len(label_pairs) == threshold_index + 1
                data[threshold_index] = [
                    e_selected,
                    e_success,
                    e_fp,
                    e_fn,
                    e_merge,
                    e_split,
                    e_gt,
                ]

        # result[1] is the node-match list (second element unpacked above).
        edit_distance, (
            split_cost,
            merge_cost,
            false_pos_cost,
            false_neg_cost,
        ) = psudo_graph_edit_distance(
            result[1],
            node_labels_x,
            node_labels_y,
            temp,
            gt_graph,
            self.location_attr,
            node_spacing=self.edit_distance_node_spacing,
            details=True,
        )

        edit_distances.append(edit_distance)
        split_costs.append(split_cost)
        merge_costs.append(merge_cost)
        false_pos_costs.append(false_pos_cost)
        false_neg_costs.append(false_neg_cost)
        num_nodes.append(len(temp.nodes))
        num_edges.append(len(temp.edges))

        # save the best version:
        # (strict "<" keeps the earliest threshold on ties)
        if best_score is None:
            best_score = edit_distance
            best_graph = copy.deepcopy(temp)
        elif edit_distance < best_score:
            best_score = edit_distance
            best_graph = copy.deepcopy(temp)

    outputs[self.output] = gp.Array(
        np.array([
            edit_distances,
            thresholds,
            num_nodes,
            num_edges,
            split_costs,
            merge_costs,
            false_pos_costs,
            false_neg_costs,
        ]),
        gp.ArraySpec(nonspatial=True),
    )
    if self.output_graph is not None:
        outputs[self.output_graph] = gp.Graph.from_nx_graph(
            best_graph,
            gp.GraphSpec(roi=batch[self.gt].spec.roi, directed=False))
    if self.details is not None:
        outputs[self.details] = gp.Graph.from_nx_graph(
            matching_details_graph,
            gp.GraphSpec(roi=batch[self.gt].spec.roi, directed=False),
        )
    return outputs
def setup(self):
    """Provide ``self.graph_key`` over the ROI of the upstream array key."""
    array_roi = self.spec[self.array_key].roi
    self.provides(self.graph_key, gp.GraphSpec(roi=array_roi))
def setup(self):
    """Provide ``self.graph_key`` over the array's ROI and cache its center.

    ``self.center`` is set to the center of the upstream array ROI as a numpy
    array.
    """
    self.provides(
        self.graph_key,
        gp.GraphSpec(roi=self.spec[self.array_key].roi),
    )
    # Looked up again after `provides`, in the same order as before.
    center = self.spec[self.array_key].roi.get_center()
    self.center = np.array(center)
def validation_pipeline(config):
    """Build the full prediction + evaluation pipeline over all blocks.

    Per block::

        raw -> normalize -> CLAHE -> embedding pred -> foreground pred -> scan
        gt skeletons (SWC)
            -> merge -> rasterize + grow labels -> candidates -> MST -> snapshot

    The block pipelines are merged, their graphs combined by ``MergeGraphs``
    and scored by ``Evaluate``.

    Returns:
        ``(validation_pipeline, score)`` where ``score`` is the ``ArrayKey``
        passed to ``Evaluate``.
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    sample_dir = Path(config["SAMPLES_PATH"])
    raw_n5 = config["RAW_N5"]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    neuron_width = int(config["NEURON_RADIUS"])
    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    # Largest voxel extent; used to scale radii, spacing and coordinates.
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    candidate_threshold = config["NMS_THRESHOLD"]
    candidate_spacing = min(config["NMS_WINDOW_SIZE"]) * micron_scale
    coordinate_scale = config["COORDINATE_SCALE"] * np.array(
        voxel_size) / micron_scale

    # Models are created once and shared by every block pipeline.
    emb_model = get_emb_model(config)
    fg_model = get_fg_model(config)

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)
        # NOTE(review): `trees` is collected but never read in this function.
        trees = []
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name[0:4] == "tree" and gt_file.name[-4:] == ".swc":
                trees.append(gt_file)
            if gt_file.name[0:4] == "cube" and gt_file.name[-4:] == ".swc":
                cube = gt_file
        # If no cube*.swc file was found, `cube` is still None here and this
        # line raises AttributeError rather than AssertionError.
        assert cube.exists()

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array([300, 300, 1000]),
        )

        # Keys are suffixed with the block id so blocks can be merged later.
        raw = gp.ArrayKey(f"RAW_{block}")
        ground_truth = gp.GraphKey(f"GROUND_TRUTH_{block}")
        labels = gp.ArrayKey(f"LABELS_{block}")
        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        mst = gp.GraphKey(f"MST_{block}")

        raw_source = (gp.ZarrSource(
            filename=str(Path(sample_dir, sample, raw_n5).absolute()),
            datasets={raw: "volume-rechunked"},
            array_specs={
                raw: gp.ArraySpec(interpolatable=True, voxel_size=voxel_size)
            },
        ) + gp.Normalize(raw, dtype=np.float32) + mCLAHE([raw], [20, 64, 64]))
        # The helpers return the extended pipeline plus the new array key.
        emb_source, emb = add_emb_pred(config, raw_source, raw, block,
                                       emb_model)
        pred_source, fg = add_fg_pred(config, emb_source, raw, block,
                                      fg_model)
        pred_source = add_scan(pred_source, {
            raw: input_size,
            emb: output_size,
            fg: output_size
        })
        swc_source = nl.gunpowder.nodes.MouselightSwcFileSource(
            validation_dir,
            [ground_truth],
            transform_file=transform_template.format(sample=sample),
            ignore_human_nodes=False,
            scale=voxel_size,
            transpose=[2, 1, 0],
            points_spec=[
                gp.PointsSpec(roi=gp.Roi(
                    gp.Coordinate([None, None, None]),
                    gp.Coordinate([None, None, None]),
                ))
            ],
        )

        additional_request = BatchRequest()
        # raw needs the cube plus half the input/output difference per side.
        input_roi = cube_roi.grow((input_size - output_size) // 2,
                                  (input_size - output_size) // 2)

        # specs[block] maps a role name to a (key, spec) pair; the same spec
        # is mirrored in the snapshot's additional request.
        block_spec = specs.setdefault(block, {})
        block_spec["raw"] = (raw, gp.ArraySpec(input_roi))
        additional_request[raw] = gp.ArraySpec(roi=input_roi)
        block_spec["ground_truth"] = (ground_truth, gp.GraphSpec(cube_roi))
        additional_request[ground_truth] = gp.GraphSpec(roi=cube_roi)
        block_spec["labels"] = (labels, gp.ArraySpec(cube_roi))
        additional_request[labels] = gp.ArraySpec(roi=cube_roi)
        block_spec["fg_pred"] = (fg, gp.ArraySpec(cube_roi))
        additional_request[fg] = gp.ArraySpec(roi=cube_roi)
        block_spec["emb_pred"] = (emb, gp.ArraySpec(cube_roi))
        additional_request[emb] = gp.ArraySpec(roi=cube_roi)
        block_spec["candidates"] = (candidates, gp.ArraySpec(cube_roi))
        additional_request[candidates] = gp.ArraySpec(roi=cube_roi)
        block_spec["mst_pred"] = (mst, gp.GraphSpec(cube_roi))
        additional_request[mst] = gp.GraphSpec(roi=cube_roi)

        pipeline = ((swc_source, pred_source) + gp.nodes.MergeProvider() +
                    nl.gunpowder.RasterizeSkeleton(
                        ground_truth,
                        labels,
                        connected_component_labeling=True,
                        array_spec=gp.ArraySpec(
                            voxel_size=voxel_size,
                            dtype=np.int64,
                            roi=gp.Roi(
                                gp.Coordinate([None, None, None]),
                                gp.Coordinate([None, None, None]),
                            ),
                        ),
                    ) + nl.gunpowder.GrowLabels(
                        labels, radii=[neuron_width * micron_scale]) +
                    Skeletonize(fg, candidates, candidate_spacing,
                                candidate_threshold) + EMST(
                                    emb,
                                    candidates,
                                    mst,
                                    distance_attr=distance_attr,
                                    coordinate_scale=coordinate_scale,
                                ) + gp.Snapshot(
                                    {
                                        raw: f"volumes/{raw}",
                                        ground_truth: f"points/{ground_truth}",
                                        labels: f"volumes/{labels}",
                                        fg: f"volumes/{fg}",
                                        emb: f"volumes/{emb}",
                                        candidates: f"volumes/{candidates}",
                                        mst: f"points/{mst}",
                                    },
                                    additional_request=additional_request,
                                    output_dir="snapshots",
                                    output_filename="{id}.hdf",
                                    edge_attrs={mst: [distance_attr]},
                                ))

        validation_pipelines.append(pipeline)

    full_gt = gp.GraphKey("FULL_GT")
    full_mst = gp.GraphKey("FULL_MST")
    score = gp.ArrayKey("SCORE")

    # Merge all block pipelines, combine their graphs, then evaluate.
    validation_pipeline = (
        tuple(pipeline for pipeline in validation_pipelines) +
        gp.MergeProvider() + MergeGraphs(specs, full_gt, full_mst) +
        Evaluate(full_gt, full_mst, score, edge_threshold_attr=distance_attr) +
        gp.PrintProfilingStats())
    return validation_pipeline, score
def emb_validation_pipeline(
    config,
    snapshot_file,
    candidates_path,
    raw_path,
    gt_path,
    candidates_mst_path=None,
    candidates_mst_dense_path=None,
    path_stat="max",
):
    """Build the embedding validation pipeline over all configured blocks.

    For every block in ``config["BLOCKS"]`` a sub-pipeline is assembled that
    reads raw data, candidates and ground truth from ``snapshot_file``,
    predicts embeddings, builds/thresholds minimum spanning trees and
    evaluates them against the ground truth. The per-block pipelines are
    merged and their scores combined by ``MergeScores``.

    Args:
        config: Mapping indexed by the configuration keys read below
            (e.g. ``"BLOCKS"``, ``"VOXEL_SIZE"``, ``"DISTANCE_ATTR"``).
        snapshot_file: Passed to ``SnapshotSource`` for both the volume and
            the graph source.
        candidates_path: Format string with a ``{block}`` placeholder naming
            the candidates dataset.
        raw_path: Format string with a ``{block}`` placeholder naming the raw
            dataset.
        gt_path: Format string with a ``{block}`` placeholder naming the
            ground-truth graph dataset.
        candidates_mst_path: Optional format string (``{block}``) of a stored
            MST. When given, it is loaded from ``snapshot_file``.
        candidates_mst_dense_path: Optional format string (``{block}``) of a
            stored dense MST. When given, it is loaded from ``snapshot_file``.
        path_stat: Currently unused; it is only referenced by the disabled
            ``ScoreEdges`` step. Kept so existing callers keep working.

    Returns:
        A tuple ``(validation_pipeline, final_score)`` where ``final_score``
        is the ``gp.ArrayKey("SCORE")`` handed to ``MergeScores``.

    Raises:
        FileNotFoundError: If a block's validation directory contains no
            ``cube*.swc`` file.
    """
    checkpoint = config["EMB_EVAL_CHECKPOINT"]
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    micron_scale = max(voxel_size)
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    distance_attr = config["DISTANCE_ATTR"]
    coordinate_scale = (
        config["COORDINATE_SCALE"] * np.array(voxel_size) / micron_scale
    )
    num_thresholds = config["NUM_EVAL_THRESHOLDS"]
    threshold_range = config["EVAL_THRESHOLD_RANGE"]
    edge_threshold_0 = config["EVAL_EDGE_THRESHOLD_0"]
    component_threshold_0 = config["COMPONENT_THRESHOLD_0"]
    component_threshold_1 = config["COMPONENT_THRESHOLD_1"]
    clip_limit = config["CLAHE_CLIP_LIMIT"]
    normalize = config["CLAHE_NORMALIZE"]

    validation_pipelines = []
    specs = {}

    emb_model = get_emb_model(config)
    emb_model.eval()

    # Both stored MSTs must be available to skip the MST computation below.
    # NOTE(review): if only one of the two paths is given, that graph is both
    # loaded by the graph source and produced again by the MST node added
    # below -- confirm whether this combination is ever intended.
    msts_precomputed = (
        candidates_mst_path is not None
        and candidates_mst_dense_path is not None
    )

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)

        # The "cube*.swc" file describes the block's bounding cube. If several
        # match, the last one yielded by iterdir() is used, as before.
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name.startswith("cube") and gt_file.name.endswith(
                    ".swc"):
                cube = gt_file
        if cube is None:
            raise FileNotFoundError(
                f"No 'cube*.swc' file found in {validation_dir}")

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array(voxel_size[::-1]),
        )

        candidates_1 = gp.ArrayKey(f"CANDIDATES_1_{block}")
        raw = gp.ArrayKey(f"RAW_{block}")
        mst_0 = gp.GraphKey(f"MST_0_{block}")
        mst_dense_0 = gp.GraphKey(f"MST_DENSE_0_{block}")
        mst_1 = gp.GraphKey(f"MST_1_{block}")
        mst_dense_1 = gp.GraphKey(f"MST_DENSE_1_{block}")
        mst_2 = gp.GraphKey(f"MST_2_{block}")
        # Not requested or snapshotted at the moment; the key is still created
        # in case key construction has side effects elsewhere -- TODO confirm.
        mst_dense_2 = gp.GraphKey(f"MST_DENSE_2_{block}")  # noqa: F841
        gt = gp.GraphKey(f"GT_{block}")
        score = gp.ArrayKey(f"SCORE_{block}")
        details = gp.GraphKey(f"DETAILS_{block}")
        optimal_mst = gp.GraphKey(f"OPTIMAL_MST_{block}")

        # Volume source
        raw_source = SnapshotSource(
            snapshot_file,
            datasets={
                raw: raw_path.format(block=block),
                candidates_1: candidates_path.format(block=block),
            },
        )

        # Graph source: ground truth plus any stored MSTs
        graph_datasets = {gt: gt_path.format(block=block)}
        graph_directionality = {gt: False}
        edge_attrs = {}
        if candidates_mst_path is not None:
            graph_datasets[mst_0] = candidates_mst_path.format(block=block)
            graph_directionality[mst_0] = False
            edge_attrs[mst_0] = [distance_attr]
        if candidates_mst_dense_path is not None:
            graph_datasets[mst_dense_0] = candidates_mst_dense_path.format(
                block=block)
            graph_directionality[mst_dense_0] = False
            edge_attrs[mst_dense_0] = [distance_attr]
        gt_source = SnapshotSource(
            snapshot_file,
            datasets=graph_datasets,
            directed=graph_directionality,
            edge_attrs=edge_attrs,
        )

        if config["EVAL_CLAHE"]:
            raw_source = raw_source + scipyCLAHE(
                [raw],
                gp.Coordinate([20, 64, 64]) * voxel_size,
                clip_limit=clip_limit,
                normalize=normalize,
            )

        emb_source, emb, neighborhood = add_emb_pred(
            config, raw_source, raw, block, emb_model)

        reference_sizes = {
            raw: input_size,
            emb: output_size,
            candidates_1: output_size,
        }
        if neighborhood is not None:
            reference_sizes[neighborhood] = output_size
        emb_source = add_scan(emb_source, reference_sizes)

        # All requests use a ROI of the cube's shape anchored at the origin;
        # the raw input is grown by the input/output context on each side.
        context = (input_size - output_size) // 2
        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow(context, context)

        output_array_keys = [candidates_1, emb]
        if neighborhood is not None:
            output_array_keys.append(neighborhood)
        graph_keys = [gt, mst_0, mst_dense_0, mst_1, mst_dense_1, mst_2]

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        for key in output_array_keys:
            block_spec[key] = gp.ArraySpec(cube_roi_shifted)
        for key in graph_keys:
            block_spec[key] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[score] = gp.ArraySpec(nonspatial=True)
        block_spec[optimal_mst] = gp.GraphSpec(cube_roi_shifted,
                                               directed=False)

        additional_request = BatchRequest()
        additional_request[raw] = gp.ArraySpec(input_roi)
        for key in output_array_keys:
            additional_request[key] = gp.ArraySpec(cube_roi_shifted)
        for key in graph_keys + [details, optimal_mst]:
            additional_request[key] = gp.GraphSpec(cube_roi_shifted,
                                                   directed=False)

        pipeline = (emb_source, gt_source) + gp.MergeProvider()

        if not msts_precomputed:
            if config["EVAL_MINIMAX_EMBEDDING_DIST"]:
                # No stored MSTs: compute mst_0 and the dense mst_0 first.
                pipeline += MiniMaxEmbeddings(
                    emb,
                    candidates_1,
                    decimated=mst_0,
                    dense=mst_dense_0,
                    distance_attr=distance_attr,
                )
            else:
                # No stored MSTs: simply use euclidean distance on candidates.
                pipeline += EMST(
                    emb,
                    candidates_1,
                    mst_0,
                    distance_attr=distance_attr,
                    coordinate_scale=coordinate_scale,
                )
                pipeline += EMST(
                    emb,
                    candidates_1,
                    mst_dense_0,
                    distance_attr=distance_attr,
                    coordinate_scale=coordinate_scale,
                )

        pipeline += ThresholdEdges(
            (mst_0, mst_1),
            edge_threshold_0,
            component_threshold_0,
            msts_dense=(mst_dense_0, mst_dense_1),
            distance_attr=distance_attr,
        )

        pipeline += ComponentWiseEMST(
            emb,
            mst_1,
            mst_2,
            distance_attr=distance_attr,
            coordinate_scale=coordinate_scale,
        )

        # A ScoreEdges(mst, mst_dense, emb, distance_attr=...,
        # path_stat=path_stat) step is currently disabled here.

        pipeline += Evaluate(
            gt,
            mst_2,
            score,
            roi=cube_roi_shifted,
            details=details,
            edge_threshold_attr=distance_attr,
            num_thresholds=num_thresholds,
            threshold_range=threshold_range,
            small_component_threshold=component_threshold_1,
            # connectivity=mst_1,
            output_graph=optimal_mst,
        )

        if config["EVAL_SNAPSHOT"]:
            # mst_2 is deliberately not written to the snapshot.
            snapshot_datasets = {
                raw: "volumes/raw",
                emb: "volumes/embeddings",
                candidates_1: "volumes/candidates_1",
                mst_0: "points/mst_0",
                mst_dense_0: "points/mst_dense_0",
                mst_1: "points/mst_1",
                mst_dense_1: "points/mst_dense_1",
                gt: "points/gt",
                details: "points/details",
                optimal_mst: "points/optimal_mst",
            }
            if neighborhood is not None:
                snapshot_datasets[neighborhood] = "volumes/neighborhood"

            pipeline += gp.Snapshot(
                snapshot_datasets,
                output_dir=config["EVAL_SNAPSHOT_DIR"],
                output_filename=config["EVAL_SNAPSHOT_NAME"].format(
                    checkpoint=checkpoint,
                    block=block,
                    coordinate_scale=",".join(
                        str(x) for x in coordinate_scale),
                ),
                edge_attrs={
                    mst_0: [distance_attr],
                    mst_dense_0: [distance_attr],
                    mst_1: [distance_attr],
                    mst_dense_1: [distance_attr],
                    # optimal_mst is left out: it is unclear how to add
                    # distances if using a connectivity graph.
                    details: ["details", "label_pair"],
                },
                node_attrs={details: ["details", "label_pair"]},
                additional_request=additional_request,
            )

        validation_pipelines.append(pipeline)

    final_score = gp.ArrayKey("SCORE")

    validation_pipeline = (tuple(validation_pipelines) + gp.MergeProvider() +
                           MergeScores(final_score, specs) +
                           gp.PrintProfilingStats())
    return validation_pipeline, final_score
def pre_computed_fg_validation_pipeline(config, snapshot_file, raw_path,
                                        gt_path, fg_path):
    """Build the validation pipeline for stored foreground predictions.

    For every block in ``config["BLOCKS"]`` a sub-pipeline is assembled that
    reads raw data, foreground and ground truth from ``snapshot_file``,
    skeletonizes the foreground into candidates, builds a graph with
    ``MiniMax`` and evaluates it against the ground truth. The per-block
    pipelines are merged and their scores combined by ``MergeScores``.

    Args:
        config: Mapping indexed by the configuration keys read below
            (e.g. ``"BLOCKS"``, ``"VOXEL_SIZE"``, ``"DISTANCE_ATTR"``).
        snapshot_file: Passed to ``SnapshotSource`` for both the volume and
            the graph source.
        raw_path: Format string with a ``{block}`` placeholder naming the raw
            dataset.
        gt_path: Format string with a ``{block}`` placeholder naming the
            ground-truth graph dataset.
        fg_path: Format string with a ``{block}`` placeholder naming the
            foreground dataset.

    Returns:
        A tuple ``(validation_pipeline, final_score)`` where ``final_score``
        is the ``gp.ArrayKey("SCORE")`` handed to ``MergeScores``.

    Raises:
        FileNotFoundError: If a block's validation directory contains no
            ``cube*.swc`` file.
    """
    blocks = config["BLOCKS"]
    benchmark_datasets_path = Path(config["BENCHMARK_DATA_PATH"])
    sample = config["VALIDATION_SAMPLES"][0]
    transform_template = "/nrs/mouselight/SAMPLES/{sample}/transform.txt"

    voxel_size = gp.Coordinate(config["VOXEL_SIZE"])
    input_shape = gp.Coordinate(config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(config["OUTPUT_SHAPE"])
    input_size = voxel_size * input_shape
    output_size = voxel_size * output_shape

    candidate_spacing = config["CANDIDATE_SPACING"]
    candidate_threshold = config["CANDIDATE_THRESHOLD"]
    distance_attr = config["DISTANCE_ATTR"]
    num_thresholds = config["NUM_EVAL_THRESHOLDS"]
    threshold_range = config["EVAL_THRESHOLD_RANGE"]
    component_threshold = config["COMPONENT_THRESHOLD_1"]

    validation_pipelines = []
    specs = {}

    for block in blocks:
        validation_dir = get_validation_dir(benchmark_datasets_path, block)

        # The "cube*.swc" file describes the block's bounding cube. If several
        # match, the last one yielded by iterdir() is used, as before.
        cube = None
        for gt_file in validation_dir.iterdir():
            if gt_file.name.startswith("cube") and gt_file.name.endswith(
                    ".swc"):
                cube = gt_file
        if cube is None:
            raise FileNotFoundError(
                f"No 'cube*.swc' file found in {validation_dir}")

        cube_roi = get_roi_from_swc(
            cube,
            Path(transform_template.format(sample=sample)),
            np.array(voxel_size[::-1]),
        )

        candidates = gp.ArrayKey(f"CANDIDATES_{block}")
        raw = gp.ArrayKey(f"RAW_{block}")
        mst = gp.GraphKey(f"MST_{block}")
        gt = gp.GraphKey(f"GT_{block}")
        fg = gp.ArrayKey(f"FG_{block}")
        score = gp.ArrayKey(f"SCORE_{block}")
        details = gp.GraphKey(f"DETAILS_{block}")

        raw_source = SnapshotSource(
            snapshot_file,
            datasets={
                raw: raw_path.format(block=block),
                fg: fg_path.format(block=block),
            },
        )
        gt_source = SnapshotSource(
            snapshot_file,
            datasets={gt: gt_path.format(block=block)},
            directed={gt: False},
        )

        # All requests use a ROI of the cube's shape anchored at the origin;
        # the raw input is grown by the input/output context on each side.
        context = (input_size - output_size) // 2
        cube_roi_shifted = gp.Roi((0, ) * len(cube_roi.get_shape()),
                                  cube_roi.get_shape())
        input_roi = cube_roi_shifted.grow(context, context)

        block_spec = specs.setdefault(block, {})
        block_spec[raw] = gp.ArraySpec(input_roi)
        block_spec[candidates] = gp.ArraySpec(cube_roi_shifted)
        block_spec[fg] = gp.ArraySpec(cube_roi_shifted)
        block_spec[gt] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[mst] = gp.GraphSpec(cube_roi_shifted, directed=False)
        block_spec[score] = gp.ArraySpec(nonspatial=True)

        additional_request = BatchRequest()
        additional_request[raw] = gp.ArraySpec(input_roi)
        additional_request[candidates] = gp.ArraySpec(cube_roi_shifted)
        additional_request[fg] = gp.ArraySpec(cube_roi_shifted)
        for key in (gt, mst, details):
            additional_request[key] = gp.GraphSpec(cube_roi_shifted,
                                                   directed=False)

        pipeline = (
            (raw_source, gt_source) + gp.MergeProvider() +
            Skeletonize(fg, candidates, candidate_spacing,
                        candidate_threshold) +
            MiniMax(fg, candidates, mst, distance_attr=distance_attr))

        pipeline += Evaluate(
            gt,
            mst,
            score,
            roi=cube_roi_shifted,
            details=details,
            edge_threshold_attr=distance_attr,
            num_thresholds=num_thresholds,
            threshold_range=threshold_range,
            small_component_threshold=component_threshold,
        )

        if config["EVAL_SNAPSHOT"]:
            pipeline += gp.Snapshot(
                {
                    raw: "volumes/raw",
                    fg: "volumes/foreground",
                    candidates: "volumes/candidates",
                    mst: "points/mst",
                    gt: "points/gt",
                    details: "points/details",
                },
                output_dir="eval_results",
                output_filename=config["EVAL_SNAPSHOT_NAME"].format(
                    block=block),
                edge_attrs={
                    mst: [distance_attr],
                    details: ["details", "label_pair"],
                },
                node_attrs={details: ["details", "label_pair"]},
                additional_request=additional_request,
            )

        validation_pipelines.append(pipeline)

    final_score = gp.ArrayKey("SCORE")

    validation_pipeline = (tuple(validation_pipelines) + gp.MergeProvider() +
                           MergeScores(final_score, specs) +
                           gp.PrintProfilingStats())
    return validation_pipeline, final_score