def process(self, batch, request):
    """Merge the ground-truth and MST graphs of all blocks into two graphs.

    For each entry of ``self.specs`` the first ``"ground_truth"`` key and the
    first ``"mst_pred"`` key are read from ``batch``, converted to undirected
    networkx graphs and appended with ``nx.disjoint_union`` (which relabels
    nodes to consecutive integers). Every node's ``"id"`` attribute is then set
    to its new label, and both merged graphs are emitted as undirected
    ``gp.Graph`` objects (unbounded 3D ROI) under ``self.gt`` and ``self.mst``.
    """
    merged_gt = nx.Graph()
    merged_mst = nx.Graph()
    for spec_entry in self.specs.values():
        gt_key = spec_entry["ground_truth"][0]
        mst_key = spec_entry["mst_pred"][0]
        block_gt = batch[gt_key].to_nx_graph().to_undirected()
        block_mst = batch[mst_key].to_nx_graph().to_undirected()
        merged_gt = nx.disjoint_union(merged_gt, block_gt)
        merged_mst = nx.disjoint_union(merged_mst, block_mst)

    # After relabelling, record each node's new integer label as its id.
    for merged in (merged_gt, merged_mst):
        for node_id, node_attrs in merged.nodes.items():
            node_attrs["id"] = node_id

    def unbounded_spec():
        return gp.GraphSpec(
            roi=gp.Roi((None, ) * 3, (None, ) * 3), directed=False)

    outputs = gp.Batch()
    outputs[self.gt] = gp.Graph.from_nx_graph(merged_gt, unbounded_spec())
    outputs[self.mst] = gp.Graph.from_nx_graph(merged_mst, unbounded_spec())
    return outputs
def process(self, batch, request):
    """Remove the leading singleton axis of ``self.array``.

    The input array is deep-copied (so the output has its own spec) and its
    data replaced by a view of the input data with axis 0 squeezed out. When
    axis 0 does not have length 1 (or the data is 0-d) the data is passed
    through unchanged. This matches ``torch.Tensor.squeeze(0)``, which the
    previous implementation used.

    The squeeze is done directly on the NumPy array. Round-tripping through
    ``torch.from_numpy`` fails for dtypes torch cannot represent (e.g.
    ``uint64`` label arrays on many torch versions) and for negatively
    strided arrays.
    """
    data = batch[self.array].data
    if data.ndim > 0 and data.shape[0] == 1:
        data = data.squeeze(axis=0)
    outputs = gp.Batch()
    outputs[self.array] = copy.deepcopy(batch[self.array])
    outputs[self.array].data = data
    return outputs
def process(self, batch, request):
    """Remove the leading singleton axis of every array in ``self.arrays``.

    Arrays missing from ``batch`` are skipped. Each present array is
    deep-copied (so the output has its own spec) and its data replaced by a
    view of the input data with axis 0 squeezed out. When axis 0 does not have
    length 1 (or the data is 0-d) the data is passed through unchanged,
    matching the ``torch.Tensor.squeeze(0)`` used previously.

    The squeeze is done directly on the NumPy array. ``torch.from_numpy``
    rejects dtypes such as ``uint64`` on many torch versions and negatively
    strided arrays.
    """
    outputs = gp.Batch()
    for array in self.arrays:
        if array not in batch:
            continue
        data = batch[array].data
        if data.ndim > 0 and data.shape[0] == 1:
            data = data.squeeze(axis=0)
        outputs[array] = copy.deepcopy(batch[array])
        outputs[array].data = data
    return outputs
def provide(self, request):
    """Provide random raw data and all-ones labels for the requested ROIs.

    Shapes are the requested ROI shapes divided by ``self.voxel_size``.

    - Raw voxels are drawn uniformly from ``[0, 256)`` with the raw spec's
      dtype, then given two leading singleton axes.
    - Labels are ones with the labels spec's dtype, given one leading
      singleton axis.

    Specs are deep copies of ``self.array_spec_raw`` /
    ``self.array_spec_labels`` with the requested ROI.
    """
    outputs = gp.Batch()

    # raw: random values, spec with requested roi
    raw_roi = request[self.raw].roi
    raw_spec = copy.deepcopy(self.array_spec_raw)
    raw_spec.roi = raw_roi
    raw_voxels = raw_roi.get_shape() / self.voxel_size
    outputs[self.raw] = gp.Array(
        np.random.randint(0, 256, raw_voxels, dtype=raw_spec.dtype),
        raw_spec,
    )
    # prepend two singleton axes after the array has been built
    raw_array = outputs[self.raw]
    raw_array.data = np.expand_dims(np.expand_dims(raw_array.data, axis=0), axis=0)

    # labels: constant ones, spec with requested roi
    labels_roi = request[self.labels].roi
    labels_spec = copy.deepcopy(self.array_spec_labels)
    labels_spec.roi = labels_roi
    labels_voxels = labels_roi.get_shape() / self.voxel_size
    outputs[self.labels] = gp.Array(
        np.ones(labels_voxels, dtype=labels_spec.dtype), labels_spec)
    # prepend one singleton axis
    labels_array = outputs[self.labels]
    labels_array.data = np.expand_dims(labels_array.data, axis=0)

    return outputs
def provide(self, request):
    """Fill every requested array with ``self.func(shape)``.

    ``shape`` is the requested ROI's shape in voxels (ROI divided by the
    array's voxel size). Each array gets a copy of the provider's spec with
    the ROI replaced by the requested one. The call is timed and the timing
    attached to the batch's profiling stats.
    """
    timing = gp.profiling.Timing(self)
    timing.start()

    batch = gp.Batch()
    for key, req_spec in request.array_specs.items():
        logger.debug("Reading %s in %s...", key, req_spec.roi)
        provided = self.spec[key]
        # Only the voxel-unit shape is used, so the ROI offset is irrelevant.
        voxel_roi = req_spec.roi / provided.voxel_size
        out_spec = provided.copy()
        out_spec.roi = req_spec.roi
        batch.arrays[key] = gp.Array(self.func(voxel_roi.get_shape()), out_spec)

    logger.debug("done")
    timing.stop()
    batch.profiling_stats.add(timing)
    return batch
def provide(self, request):
    """Provide an all-zero array for every requested array key.

    Each array gets a copy of the provider's spec with the ROI set to the
    requested one.

    - The data shape is the ROI shape divided by the spec's voxel size (when
      set), because ROIs are in world units and arrays are indexed in voxels.
    - The dtype is the spec's dtype when set (``float64`` otherwise), so the
      data agrees with the spec it is paired with.
    """
    batch = gp.Batch()
    for array_key, request_spec in request.array_specs.items():
        array_spec = self.spec[array_key].copy()
        array_spec.roi = request_spec.roi

        shape = array_spec.roi.get_shape()
        if array_spec.voxel_size is not None:
            shape = shape / array_spec.voxel_size
        # Was a Python-2 print statement (a SyntaxError under Python 3).
        print("array_spec: ", array_spec.roi.get_shape())

        dtype = array_spec.dtype if array_spec.dtype is not None else np.float64
        data = np.zeros(shape, dtype=dtype)
        batch.arrays[array_key] = gp.Array(data, array_spec)
    return batch
def provide(self, request):
    """Synthesize a random segmentation (``self.gt``) and matching raw image (``self.raw``).

    Seeds:
        ``self.n_objects`` random skeletons are drawn with ``skelerator`` into
        a seed volume of shape ``(1,) + raw ROI shape``. The first skeleton
        uses 100 random points, the others ``self.points_per_skeleton``.

    Segmentation:
        The seeds are grown with a seeded watershed (``cwatershed``) over
        their distance transform plus 5.0 times Gaussian-smoothed noise
        (``sigma=self.smoothness``).

    Raw image:
        Each label (in ``np.unique`` order) is painted with a constant gray
        value that grows by ``255 / self.n_objects`` per label.

    The ground truth is cut from the full volume at the GT ROI, offset by the
    raw ROI's begin and divided by the raw voxel size. Both outputs are given
    voxel size ``(1, 1)``.
    """
    voxel_size = self.spec[self.raw].voxel_size
    # NOTE(review): the ROI shape is used as-is (world units), while gt_crop
    # below divides by voxel_size; consistent only if voxel size is 1 — confirm.
    shape = gp.Coordinate((1, ) + request[self.raw].roi.get_shape())
    noise = np.abs(np.random.randn(*shape))
    smoothed_noise = gaussian_filter(noise, sigma=self.smoothness)
    seeds = np.zeros(shape, dtype=int)
    for i in range(self.n_objects):
        if i == 0:
            num_points = 100
        else:
            num_points = self.points_per_skeleton
        # shape[0] == 1, so the first coordinate of every point is 0.
        points = np.stack(
            [
                np.random.randint(0, shape[dim], num_points)
                for dim in range(3)
            ],
            axis=1,
        )
        tree = skelerator.Tree(points)
        skeleton = skelerator.Skeleton(tree, [1, 1, 1], "linear", generate_graph=False)
        # presumably draws skeleton i with value i + 1 (skelerator API not visible here)
        seeds = skeleton.draw(seeds, np.array([0, 0, 0]), i + 1)
    # Keep only seed voxels equal to their 4-neighbourhood maximum, i.e. clear
    # seed voxels that touch a higher-valued seed.
    seeds[maximum_filter(seeds, size=4) != seeds] = 0
    seeds_dt = distance_transform_edt(seeds == 0) + 5.0 * smoothed_noise
    # [0] drops the leading singleton axis. The - 1 presumably shifts watershed
    # labels to start at 0 — TODO confirm; a 0 label would wrap in uint64.
    gt_data = cwatershed(seeds_dt, seeds).astype(np.uint64)[0] - 1
    labels = np.unique(gt_data)
    raw_data = np.zeros_like(gt_data, dtype=np.uint8)
    value = 0
    for label in labels:
        # value is a float; assignment into the uint8 array truncates it.
        raw_data[gt_data == label] = value
        value += 255.0 / self.n_objects
    spec = request[self.raw].copy()
    spec.voxel_size = (1, 1)
    raw = gp.Array(raw_data, spec)
    spec = request[self.gt].copy()
    spec.voxel_size = (1, 1)
    # GT ROI relative to the raw ROI's begin, in voxels, as slices into gt_data.
    gt_crop = (request[self.gt].roi - request[self.raw].roi.get_begin()) / voxel_size
    gt_crop = gt_crop.to_slices()
    gt = gp.Array(gt_data[gt_crop], spec)
    batch = gp.Batch()
    batch[self.raw] = raw
    batch[self.gt] = gt
    return batch
def provide(self, request):
    """Provide an edgeless, undirected graph of random points in the requested ROI.

    ``self.random_point_generator.get_random_points(roi)`` supplies a mapping
    from node id to location; each entry becomes one ``gp.Node``.
    """
    roi = request[self.graph_key].roi
    id_to_location = self.random_point_generator.get_random_points(roi)

    nodes = []
    for node_id, location in id_to_location.items():
        nodes.append(gp.Node(id=node_id, location=location))

    graph_spec = gp.GraphSpec(roi=roi, directed=False)
    batch = gp.Batch()
    batch[self.graph_key] = gp.Graph(nodes, [], graph_spec)
    return batch
def provide(self, request):
    """Provide ``M_PRED`` and ``D_PRED`` cut out at the requested ``M_PRED`` ROI.

    The ROI is converted to voxel slices with ``self.voxel_size``.
    ``self.m_pred`` is indexed with those slices directly. ``self.d_pred``
    keeps its leading axis whole and takes the first three slices on the
    following axes. Both arrays get a spec with the requested ROI and
    ``self.voxel_size``.
    """
    roi = request[gp.ArrayKeys.M_PRED].roi
    voxel_slices = (roi / self.voxel_size).to_slices()

    m_pred_crop = self.m_pred[voxel_slices]
    d_pred_crop = self.d_pred[:, voxel_slices[0], voxel_slices[1], voxel_slices[2]]

    batch = gp.Batch()
    batch.arrays[gp.ArrayKeys.M_PRED] = gp.Array(
        m_pred_crop,
        spec=gp.ArraySpec(roi=roi, voxel_size=self.voxel_size))
    batch.arrays[gp.ArrayKeys.D_PRED] = gp.Array(
        d_pred_crop,
        spec=gp.ArraySpec(roi=roi, voxel_size=self.voxel_size))
    return batch
def process(self, batch, request):
    """Binarize ``self.in_array``: True wherever a voxel differs from ``self.target``.

    Returns ``None`` when ``self.in_array`` is not in ``batch``. Otherwise
    returns a batch holding a boolean array under ``self.out_array``, using a
    copy of the input spec with its dtype set to ``bool``.
    """
    if self.in_array not in batch:
        return None
    source = batch[self.in_array]
    spec = source.spec.copy()
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the same dtype.
    spec.dtype = bool
    outputs = gp.Batch()
    outputs[self.out_array] = gp.Array(source.data != self.target, spec)
    return outputs
def process(self, batch, request):
    """Stack per-block score arrays into one non-spatial array.

    Every batch entry whose key contains ``"SCORE"`` is a block score. Its
    block number is the integer after the first ``_`` in the key name; a later
    entry with the same number overwrites an earlier one. Scores for blocks
    1..25 that are present are stacked in block order and emitted under
    ``self.output``.
    """
    scores_by_block = {}
    for key, array in batch.items():
        key_name = str(key)
        if "SCORE" not in key_name:
            continue
        scores_by_block[int(key_name.split("_")[1])] = array.data

    ordered = [scores_by_block[b] for b in range(1, 26) if b in scores_by_block]

    outputs = gp.Batch()
    outputs[self.output] = gp.Array(
        np.array(ordered), gp.ArraySpec(nonspatial=True))
    return outputs
def process(self, batch, request):
    """Gaussian-smooth ``self.array`` and return it cropped to the requested ROI.

    ``filters.gaussian`` (scikit-image) is applied with ``sigma=self.sigma``,
    ``mode='constant'`` and ``preserve_range=True``. Every axis is treated as
    spatial (no channel axis). The upstream array object's data is replaced in
    place before cropping.
    """
    array = batch.arrays[self.array]
    gaussian_kwargs = dict(sigma=self.sigma, mode='constant', preserve_range=True)
    try:
        # scikit-image >= 0.19 spells "no channel axis" as channel_axis=None;
        # the old ``multichannel`` keyword was deprecated then and later removed.
        smoothed = filters.gaussian(array.data, channel_axis=None, **gaussian_kwargs)
    except TypeError:
        # Older scikit-image releases only understand ``multichannel``.
        smoothed = filters.gaussian(array.data, multichannel=False, **gaussian_kwargs)
    array.data = smoothed

    outputs = gp.Batch()
    outputs[self.array] = array.crop(request[self.array].roi)
    return outputs
def process(self, batch, request):
    """Rescale ``self.source`` by ``self.factor`` and emit it as ``self.target``.

    Uses ``rescale(data, self.factor)`` with its default options (presumably
    ``skimage.transform.rescale`` — confirm against the module imports). The
    output spec is a copy of the provider's spec for ``self.target`` with the
    ROI replaced by the requested one.
    """
    resized = rescale(batch.arrays[self.source].data, self.factor)

    target_spec = self.spec[self.target].copy()
    target_spec.roi = request[self.target].roi

    outputs = gp.Batch()
    outputs.arrays[self.target] = gp.Array(resized, target_spec)
    return outputs
def provide(self, request):
    """Provide points either as a random array sample or as a graph cropped to the request.

    ArrayKey case (``self.points`` is an ``ArrayKey``):
        ``self.num_points`` rows of ``self.data`` are drawn with replacement
        (``np.random.choice``) and given a leading axis. They are optionally
        multiplied by ``self.scale`` and returned with the provider's spec.

    Graph case (otherwise):
        Rows whose first ``self.ndims`` coordinates lie in the half-open
        request ROI ``[begin, end)`` are selected and turned into an edgeless
        graph via ``self._get_points``.

    When ``self.label_data`` is set, the matching labels are always returned
    as an array under ``self.labels``. The call is timed and the timing
    attached to the batch's profiling stats.
    """
    timing = Timing(self)
    timing.start()

    batch = gp.Batch()

    if isinstance(self.points, gp.ArrayKey):
        # An Array was requested: sample a fixed number of points at random.
        chosen = np.random.choice(self.data.shape[0], self.num_points)
        data = self.data[chosen][np.newaxis]
        if self.scale is not None:
            data = data * self.scale
        if self.label_data is not None:
            labels = self.label_data[chosen]
        batch[self.points] = gp.Array(data, self.spec[self.points])
    else:
        # A Graph was requested: keep only points inside the request ROI.
        roi = request[self.points].roi
        min_bb = roi.get_begin()
        max_bb = roi.get_end()
        logger.debug("Points source got request for %s", roi)

        # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is equivalent.
        point_filter = np.ones((self.data.shape[0], ), dtype=bool)
        for d in range(self.ndims):
            point_filter = np.logical_and(point_filter, self.data[:, d] >= min_bb[d])
            point_filter = np.logical_and(point_filter, self.data[:, d] < max_bb[d])

        points_data, labels = self._get_points(point_filter)
        logger.debug(f"Found {len(points_data)} points")

        points_spec = gp.GraphSpec(roi=roi.copy())
        batch.graphs[self.points] = gp.Graph(points_data, [], points_spec)

    # Labels will always be an Array.
    if self.label_data is not None:
        batch[self.labels] = gp.Array(labels, self.spec[self.labels])

    timing.stop()
    batch.profiling_stats.add(timing)
    return batch
def process(self, batch, request):
    """Compute the training target and weights from the ground truth.

    - ``self.predictor.create_target`` builds the target from the ground
      truth (``self.gt_key``).
    - ``self.predictor.create_weight`` builds the weights from ground truth,
      target and mask (``self.mask_key``).

    Both are cut to their requested ROIs and emitted under
    ``self.target_key`` / ``self.weights_key``. Each uses a copy of its
    request spec whose voxel size is the ground truth's.
    """
    output = gp.Batch()

    gt_array = NumpyArray.from_gp_array(batch[self.gt_key])
    target_array = self.predictor.create_target(gt_array)
    mask_array = NumpyArray.from_gp_array(batch[self.mask_key])
    weight_array = self.predictor.create_weight(
        gt_array, target_array, mask=mask_array
    )

    for key, array in (
        (self.target_key, target_array),
        (self.weights_key, weight_array),
    ):
        # Copy the spec: the request object is shared with the rest of the
        # pipeline and previously had its voxel_size overwritten in place.
        spec = request[key].copy()
        spec.voxel_size = gt_array.voxel_size
        output[key] = gp.Array(array[spec.roi], spec)

    return output
def provide(self, request):
    """Synthesize random line-segment labels (``self.gt``) and a noisy raw image (``self.raw``).

    Labels:
        ``self.n_objects`` skeletons, each built from 2 random points, are
        drawn with ``skelerator`` into a volume of shape
        ``(1,) + raw ROI shape``. Skeleton ``i`` gets value ``i + 1``. The
        result is dilated with a size-2 maximum filter.

    Raw image:
        The foreground mask (label > 0) plus Gaussian noise with scale 0.1,
        clipped to ``[0, 1]``, as float32.

    The ground truth is cut at the GT ROI, offset by the raw ROI's begin and
    divided by the raw voxel size. Both outputs are given voxel size
    ``(1, 1)``.
    """
    voxel_size = self.spec[self.raw].voxel_size
    # NOTE(review): ROI shape used as-is (world units) while gt_crop divides by
    # voxel_size; consistent only for unit voxel size — confirm.
    shape = gp.Coordinate((1, ) + request[self.raw].roi.get_shape())
    gt_data = np.zeros(shape, dtype=int)
    for i in range(self.n_objects):
        # two random endpoints per skeleton; shape[0] == 1 so the first
        # coordinate is always 0
        points = np.stack(
            [np.random.randint(0, shape[dim], 2) for dim in range(3)],
            axis=1)
        tree = skelerator.Tree(points)
        skeleton = skelerator.Skeleton(tree, [1, 1, 1], "linear", generate_graph=False)
        gt_data = skeleton.draw(gt_data, np.array([0, 0, 0]), i + 1)
    # drop the leading singleton axis
    gt_data = gt_data[0].astype(np.uint64)
    # thicken the drawn skeletons
    gt_data = maximum_filter(gt_data, size=2)
    # NOTE(review): `labels` is never used below.
    labels = np.unique(gt_data)
    raw_data = (gt_data > 0).astype(np.float32)
    raw_data = np.clip(
        raw_data + np.random.normal(scale=0.1, size=raw_data.shape),
        0, 1).astype(np.float32)
    spec = request[self.raw].copy()
    spec.voxel_size = (1, 1)
    raw = gp.Array(raw_data, spec)
    spec = request[self.gt].copy()
    spec.voxel_size = (1, 1)
    # GT ROI relative to the raw ROI's begin, in voxels, as slices into gt_data.
    gt_crop = (request[self.gt].roi - request[self.raw].roi.get_begin()) / voxel_size
    gt_crop = gt_crop.to_slices()
    gt = gp.Array(gt_data[gt_crop], spec)
    batch = gp.Batch()
    batch[self.raw] = raw
    batch[self.gt] = gt
    return batch
def provide(self, request):
    """Provide a synthetic segmentation (``create_segmentation``) for every requested array.

    ``create_segmentation`` is called with the provider's parameters on a
    shape in which every odd extent of the requested ROI shape is rounded up
    by one. The result is then cropped back to the requested shape along the
    leading axes, one per ROI dimension. This works for any dimensionality;
    the previous crop hard-coded three axes.
    """
    batch = gp.Batch()
    for array_key, request_spec in request.array_specs.items():
        array_spec = self.spec[array_key].copy()
        array_spec.roi = request_spec.roi
        shape = array_spec.roi.get_shape()

        # Enlarge odd extents by one (presumably create_segmentation needs
        # even sizes — TODO confirm).
        padded_shape = gp.Coordinate([s + 1 if s % 2 != 0 else s for s in shape])

        data = create_segmentation(
            shape=padded_shape,
            n_objects=self.n_objects,
            points_per_skeleton=self.points_per_skeleton,
            interpolation=self.interpolation,
            smoothness=self.smoothness,
            noise_strength=self.noise_strength,
            seed=self.seed)

        # Crop the padding off again: keep the first shape[i] entries per axis.
        crop = tuple(slice(0, s) for s in shape)
        segmentation = data["segmentation"][crop]

        batch.arrays[array_key] = gp.Array(segmentation, array_spec)
    return batch
def provide(self, request):
    """Read the requested ROI of ``self.array`` and return it under ``self.key``.

    A channel axis is prepended when ``self.array.axes`` contains no ``"c"``.
    Raises ``ValueError`` if the data contains NaN. The read is timed and the
    timing attached to the batch's profiling stats.
    """
    output = gp.Batch()

    timer = Timing(self, "provide")
    timer.start()

    spec = self.array_spec.copy()
    spec.roi = request[self.key].roi
    data = self.array[spec.roi]

    has_channel_axis = "c" in self.array.axes
    if not has_channel_axis:
        data = np.expand_dims(data, 0)

    if np.isnan(data).any():
        raise ValueError("INPUT DATA CAN'T BE NAN")

    output[self.key] = gp.Array(data, spec=spec)

    timer.stop()
    output.profiling_stats.add(timer)
    return output
def process(self, batch, request):
    """Simulate random cages on the raw data and emit cage and density maps.

    ``simulate_random_cages`` (brembow) receives the raw and segmentation data
    wrapped as ``Volume`` objects together with the provider's cage, density,
    PSF and no-cage-probability settings. It returns a cage map and a density
    map.

    Outputs:
        ``self.raw``: the simulated raw volume's data, with the original raw spec.
        ``self.cage_map``: uint64, spec derived from the segmentation spec,
            cropped to the requested ROI.
        ``self.density_map``: float32, spec derived from the segmentation spec,
            cropped to the requested ROI.
    """
    raw = batch[self.raw]
    seg = batch[self.seg]
    print(f"RAW: {raw}")
    print(f"SEG: {seg}")

    simulated_raw = Volume(raw.data, raw.spec.voxel_size)
    seg_volume = Volume(seg.data, seg.spec.voxel_size)
    cage_map, density_map = simulate_random_cages(
        simulated_raw,
        seg_volume,
        self.cages,
        self.min_density,
        self.max_density,
        self.psf,
        True,
        True,
        self.no_cage_probability,
    )

    # Specs for the new arrays: raw keeps its own spec; the maps derive from
    # the segmentation spec with their own dtypes.
    raw_spec = raw.spec.copy()
    cage_map_spec = seg.spec.copy()
    cage_map_spec.dtype = np.uint64
    density_map_spec = seg.spec.copy()
    density_map_spec.dtype = np.float32

    print(cage_map_spec)
    cage_map_array = gp.Array(data=cage_map, spec=cage_map_spec).crop(
        request[self.cage_map].roi)
    density_map_array = gp.Array(data=density_map, spec=density_map_spec).crop(
        request[self.density_map].roi)

    processed = gp.Batch()
    processed[self.raw] = gp.Array(data=simulated_raw.data, spec=raw_spec)
    processed[self.cage_map] = cage_map_array
    processed[self.density_map] = density_map_array
    return processed
def process(self, batch, request):
    """Remove connected components that reach the central ROI with a non-fallback node.

    Central ROI:
        The graph's ROI grown (or shrunk, for a negative difference)
        symmetrically so its extent becomes ``self.centroid_size``. When
        ``self.centroid_size`` is None it is the full ROI.

    A connected component is removed from the graph when any of its nodes has
    id ``>= self.node_offset`` (i.e. is not a fallback node) and lies in the
    central ROI. The graph in ``batch`` is modified in place. It is also
    returned under ``self.points``; previously the output batch was built but
    never returned.
    """
    outputs = gp.Batch()
    graph = batch[self.points]

    full_roi = graph.spec.roi
    size = full_roi.get_shape()
    small_roi = full_roi.copy()
    if self.centroid_size is not None:
        diff = self.centroid_size - size
        diff = diff / gp.Coordinate([2] * len(diff))
        small_roi = small_roi.grow(diff, diff)
    centered_graph = graph.crop(small_roi)

    # Materialize the components first: nodes are removed while iterating.
    for wcc in list(graph.connected_components):
        reaches_center = any(
            node >= self.node_offset and centered_graph.contains(node)
            for node in wcc
        )
        if reaches_center:
            for node in wcc:
                graph.remove_node(gp.Node(id=node, location=None))

    outputs[self.points] = graph
    return outputs
def process(self, batch, request):
    """Compute star-convex distances for ``self.label_key`` and emit them under ``self.stardist_key``.

    ``star_dist3d_custom`` is called with the provider's rays and parameters,
    with ``grid=self.grid``, so no extra subsampling is applied. Its last axis
    is moved to the front because gunpowder expects the channel axis first.

    The output spec is derived from the label spec via ``self._updated_spec``,
    with the requested ROI. The returned batch is merged into the existing
    batch by gunpowder.
    """
    labels = batch.arrays[self.label_key].data
    distances = star_dist3d_custom(
        labels,
        self.rays,
        self.unlabeled_id,
        self.max_dist,
        invalid_value=self.invalid_value,
        grid=self.grid,
        voxel_size=self.anisotropy,
        mode=self.sd_mode,
    )
    channels_first = np.moveaxis(distances, -1, 0)

    out_spec = self._updated_spec(batch[self.label_key].spec)
    out_spec.roi = request[self.stardist_key].roi.copy()

    outputs = gp.Batch()
    outputs[self.stardist_key] = gp.Array(channels_first, out_spec)
    return outputs
def provide(self, request):
    """Load the arrays and graphs listed in ``self.datasets`` from a snapshot file.

    ``self.snapshot_file`` must end in ``.h5``/``.hdf`` (opened read-only with
    h5py) or ``.zarr`` (opened read-only with zarr).

    - Array keys are read with ``self.array_from_path``.
    - Graph keys are read with ``self.graph_from_path`` and have their
      connected components relabelled.

    The resulting batch is cropped to ``request``.

    Raises:
        ValueError: if the snapshot file has an unsupported extension.
            Previously this fell through and failed later with an
            UnboundLocalError.
    """
    import logging

    snapshot = str(self.snapshot_file)
    if snapshot.endswith((".h5", ".hdf")):
        data = h5py.File(self.snapshot_file, "r")
    elif snapshot.endswith(".zarr"):
        data = zarr.open(self.snapshot_file, "r")
    else:
        raise ValueError(
            f"Unsupported snapshot file {self.snapshot_file!s}: "
            f"expected a .h5, .hdf or .zarr path"
        )

    outputs = gp.Batch()
    for key, path in self.datasets.items():
        if isinstance(key, gp.ArrayKey):
            outputs[key] = self.array_from_path(data, path)
        elif isinstance(key, gp.GraphKey):
            result = self.graph_from_path(key, data, path)
            result.relabel_connected_components()
            # Counting components is costly; only do it when it will be logged.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    f"Reading graph {key} with {result.num_vertices()} nodes, "
                    f"{result.num_edges()} edges, and {len(list(result.connected_components))} "
                    f"connected_components"
                )
            outputs[key] = result
    return outputs.crop(request)
def process(self, batch, request):
    """Cast ``self.array`` to ``int64`` and set its spec dtype to match.

    The input array is deep-copied first, so the output has its own spec
    object.
    """
    source = batch[self.array]
    converted = copy.deepcopy(source)
    converted.data = source.data.astype(np.int64)
    converted.spec.dtype = np.int64

    outputs = gp.Batch()
    outputs[self.array] = converted
    return outputs
def process(self, batch, request):
    """Score thresholded MSTs against the ground truth over a range of thresholds.

    Thresholds:
        MST edges (``self.mst``) are sorted by ``self.edge_threshold_attr``.
        ``self.num_thresholds`` thresholds are spaced evenly between the
        edge values at the quantiles given by ``self.threshold_range``.

    Per threshold:
        1. Edges at or below the threshold are added to a growing graph,
           optionally re-expressed with the nodes/edges of
           ``self.connectivity``.
        2. Components whose bounding-box diagonal is below
           ``self.small_component_threshold`` are dropped.
        3. The rest is matched to the ground truth (``self.gt``) with
           ``get_edges_xy``/``match_components``.
        4. It is scored with ``psudo_graph_edit_distance``.

    Outputs:
        self.output: non-spatial array with rows edit distance, threshold,
            #nodes, #edges, split, merge, false-positive and false-negative
            cost (one column per threshold).
        self.output_graph (optional): the thresholded graph with the lowest
            edit distance.
        self.details (optional): union of the MST and GT graphs, with GT ids
            shifted by ``node_offset``. Every node/edge carries a ``"details"``
            bool array of shape (num_thresholds, 7) with columns
            selected, success, fp, fn, merge, split, gt, and a
            ``"label_pair"`` list with one entry per threshold.
    """
    num_thresholds = self.num_thresholds
    threshold_range = self.threshold_range

    outputs = gp.Batch()

    gt_graph = batch[self.gt].to_nx_graph().to_undirected()
    mst_graph = batch[self.mst].to_nx_graph().to_undirected()
    if self.connectivity is not None:
        connectivity_graph = batch[self.connectivity].to_nx_graph().to_undirected()

    if self.details is not None:
        matching_details_graph = nx.Graph()
        # GT node ids are shifted past the largest MST id so both node sets
        # fit in one graph. Computed unconditionally: it used to be assigned
        # only for an empty MST, so details on a non-empty MST raised NameError.
        node_offset = max([node for node in mst_graph.nodes] + [-1]) + 1
        for node, attrs in mst_graph.nodes.items():
            matching_details_graph.add_node(node, **copy.deepcopy(attrs))
        for edge, attrs in mst_graph.edges.items():
            matching_details_graph.add_edge(edge[0], edge[1], **copy.deepcopy(attrs))
        for node, attrs in gt_graph.nodes.items():
            matching_details_graph.add_node(node + node_offset, **copy.deepcopy(attrs))
            matching_details_graph.nodes[node + node_offset]["id"] = node + node_offset
        for edge, attrs in gt_graph.edges.items():
            matching_details_graph.add_edge(
                edge[0] + node_offset, edge[1] + node_offset, **copy.deepcopy(attrs)
            )

    # MST edges paired with their threshold attribute, ascending.
    edges = sorted(
        (
            (edge, attrs[self.edge_threshold_attr])
            for edge, attrs in mst_graph.edges.items()
        ),
        key=lambda x: x[1],
    )
    edge_lens = [e[1] for e in edges]
    if len(edge_lens) > 0:
        min_threshold = edge_lens[int(len(edge_lens) * threshold_range[0])]
        max_threshold = edge_lens[int(len(edge_lens) * threshold_range[1]) - 1]
    else:
        min_threshold = 0
        max_threshold = 1
    thresholds = np.linspace(min_threshold, max_threshold, num=num_thresholds)

    # Thresholds increase, so the graph grows incrementally: each iteration
    # only pops the edges that newly fall at or below the threshold.
    current_threshold_mst = nx.Graph()
    edge_deque = deque(edges)

    edit_distances = []
    split_costs = []
    merge_costs = []
    false_pos_costs = []
    false_neg_costs = []
    num_nodes = []
    num_edges = []
    best_score = None
    best_graph = None

    for threshold_index, threshold in enumerate(thresholds):
        logger.warning(f"Using threshold: {threshold}")
        while len(edge_deque) > 0 and edge_deque[0][1] <= threshold:
            (u, v), _ = edge_deque.popleft()
            current_threshold_mst.add_edge(u, v)
            current_threshold_mst.add_node(u, **mst_graph.nodes[u])
            current_threshold_mst.add_node(v, **mst_graph.nodes[v])

        # `temp` is the graph evaluated at this threshold; each node carries
        # the index of its thresholded-MST component as "component".
        if self.connectivity is not None:
            # Re-express each component with the connectivity graph's nodes,
            # keeping only connectivity edges within a single component.
            temp = nx.Graph()
            for i, cc in enumerate(nx.connected_components(current_threshold_mst)):
                component_subgraph = current_threshold_mst.subgraph(cc)
                for node in component_subgraph.nodes:
                    temp.add_node(node, **dict(connectivity_graph.nodes[node]))
                    temp.nodes[node]["component"] = i
            for edge in connectivity_graph.edges:
                if (
                    edge[0] in temp.nodes
                    and edge[1] in temp.nodes
                    and temp.nodes[edge[0]]["component"]
                    == temp.nodes[edge[1]]["component"]
                ):
                    temp.add_edge(
                        edge[0], edge[1], **dict(connectivity_graph.edges[edge])
                    )
        else:
            temp = copy.deepcopy(current_threshold_mst)
            for i, cc in enumerate(nx.connected_components(temp)):
                for node in cc:
                    temp.nodes[node]["component"] = i

        # Remove small connected components: bounding-box diagonal below
        # self.small_component_threshold.
        false_pos_nodes = []
        for cc in nx.connected_components(temp):
            locs = np.array([temp.nodes[node][self.location_attr] for node in cc])
            diagonal = np.linalg.norm(locs.min(axis=0) - locs.max(axis=0))
            if diagonal < self.small_component_threshold:
                false_pos_nodes += list(cc)
        for node in false_pos_nodes:
            temp.remove_node(node)

        nodes_x = list(temp.nodes)
        nodes_y = list(gt_graph.nodes)
        node_labels_x = {
            node: attrs["component"] for node, attrs in temp.nodes.items()
        }
        node_labels_y = {
            node: component
            for component, nodes in enumerate(nx.connected_components(gt_graph))
            for node in nodes
        }

        edges_yx = get_edges_xy(
            gt_graph,
            temp,
            location_attr=self.location_attr,
            node_match_threshold=self.comatch_threshold,
        )
        edges_xy = [(v, u) for u, v in edges_yx]
        result = match_components(
            nodes_x,
            nodes_y,
            edges_xy,
            copy.deepcopy(node_labels_x),
            copy.deepcopy(node_labels_y),
        )

        if self.details is not None:
            # Per-threshold match details. Columns are selected, success, fp,
            # fn, merge, split, gt.
            # Nodes:
            #   success: matches only nodes of one label, which matches its own
            #   fp / fn: MST / GT node that matches nothing
            #   merge / split: MST / GT node matching a node of another label
            #   selected: MST node in the thresholded graph (GT always selected)
            # Edges:
            #   success: both endpoints successful with the same label pair
            #   fp / fn: both endpoints fp / fn
            #   merge / split: selected, not successful, on the MST / GT side
            (label_matches, node_matches, splits, merges, fps, fns) = result

            # Lookup tables for label and node matches in both directions.
            x_label_match_lut = {}
            y_label_match_lut = {}
            for a, b in label_matches:
                x_label_match_lut.setdefault(a, set()).add(b)
                y_label_match_lut.setdefault(b, set()).add(a)
            x_node_match_lut = {}
            y_node_match_lut = {}
            for a, b in node_matches:
                x_node_match_lut.setdefault(a, set()).add(b)
                y_node_match_lut.setdefault(b, set()).add(a)

            for node, attrs in matching_details_graph.nodes.items():
                gt = int(node >= node_offset)
                if gt == 1:
                    # back to the original GT node id
                    node = node - node_offset
                selected = gt or (node in temp.nodes())
                if selected:
                    success, fp, fn, merge, split, label_pair = self.node_matching_result(
                        node,
                        gt,
                        x_label_match_lut,
                        y_label_match_lut,
                        x_node_match_lut,
                        y_node_match_lut,
                        node_labels_x,
                        node_labels_y,
                    )
                else:
                    success, fp, fn, merge, split, label_pair = (
                        0, 0, 0, 0, 0, (-1, -1),
                    )
                data = attrs.setdefault(
                    "details", np.zeros((len(thresholds), 7), dtype=bool))
                data[threshold_index] = [selected, success, fp, fn, merge, split, gt]
                label_pairs = attrs.setdefault("label_pair", [])
                label_pairs.append(label_pair)
                assert len(label_pairs) == threshold_index + 1

            for (u, v), attrs in matching_details_graph.edges.items():
                (
                    u_selected, u_success, u_fp, u_fn, u_merge, u_split, u_gt,
                ) = matching_details_graph.nodes[u]["details"][threshold_index]
                (
                    v_selected, v_success, v_fp, v_fn, v_merge, v_split, v_gt,
                ) = matching_details_graph.nodes[v]["details"][threshold_index]
                assert u_gt == v_gt
                e_gt = u_gt

                u_label_pair = matching_details_graph.nodes[u]["label_pair"][threshold_index]
                v_label_pair = matching_details_graph.nodes[v]["label_pair"][threshold_index]

                e_selected = u_selected and v_selected
                e_success = (
                    e_selected and u_success and v_success
                    and (u_label_pair == v_label_pair)
                )
                e_fp = u_fp and v_fp
                e_fn = u_fn and v_fn
                e_merge = e_selected and (not e_success) and (not e_fp) and not e_gt
                e_split = e_selected and (not e_success) and (not e_fn) and e_gt
                assert not (e_success and e_merge)
                assert not (e_success and e_split)

                data = attrs.setdefault(
                    "details", np.zeros((len(thresholds), 7), dtype=bool))
                label_pairs = attrs.setdefault("label_pair", [])
                label_pairs.append(u_label_pair if e_success else (-1, -1))
                assert len(label_pairs) == threshold_index + 1
                data[threshold_index] = [
                    e_selected, e_success, e_fp, e_fn, e_merge, e_split, e_gt,
                ]

        edit_distance, (
            split_cost,
            merge_cost,
            false_pos_cost,
            false_neg_cost,
        ) = psudo_graph_edit_distance(
            result[1],
            node_labels_x,
            node_labels_y,
            temp,
            gt_graph,
            self.location_attr,
            node_spacing=self.edit_distance_node_spacing,
            details=True,
        )

        edit_distances.append(edit_distance)
        split_costs.append(split_cost)
        merge_costs.append(merge_cost)
        false_pos_costs.append(false_pos_cost)
        false_neg_costs.append(false_neg_cost)
        num_nodes.append(len(temp.nodes))
        num_edges.append(len(temp.edges))

        # Keep the thresholded graph with the lowest edit distance.
        if best_score is None or edit_distance < best_score:
            best_score = edit_distance
            best_graph = copy.deepcopy(temp)

    outputs[self.output] = gp.Array(
        np.array([
            edit_distances,
            thresholds,
            num_nodes,
            num_edges,
            split_costs,
            merge_costs,
            false_pos_costs,
            false_neg_costs,
        ]),
        gp.ArraySpec(nonspatial=True),
    )
    if self.output_graph is not None:
        outputs[self.output_graph] = gp.Graph.from_nx_graph(
            best_graph,
            gp.GraphSpec(roi=batch[self.gt].spec.roi, directed=False))
    if self.details is not None:
        outputs[self.details] = gp.Graph.from_nx_graph(
            matching_details_graph,
            gp.GraphSpec(roi=batch[self.gt].spec.roi, directed=False),
        )
    return outputs