def runTest(self): print "ManyToMany" mg = MatchingGraph(self.groundtruth_skeletons, self.reconstructed_skeletons, self.distance_threshold, self.voxel_size, verbose=True, distance_cost=True, initialize_all=True) nodes_gt, nodes_rec, edges_gt_rec, labels_gt, labels_rec, edge_costs, edge_conflicts, edge_pairs = mg.export_to_comatch() try: # Quadmatch label_matches, node_matches, num_splits, num_merges, num_fps, num_fns = match_components(nodes_gt, nodes_rec, edges_gt_rec, labels_gt, labels_rec, edge_conflicts=edge_conflicts, max_edges=10, edge_costs=edge_costs) except TypeError: # Comatch label_matches, node_matches, num_splits, num_merges, num_fps, num_fns = match_components(nodes_gt, nodes_rec, edges_gt_rec, labels_gt, labels_rec, allow_many_to_many=True, edge_costs=edge_costs, no_match_costs=1000.) print "label matches:", label_matches print "node_matches:", node_matches comatch_errors = {"splits": num_splits, "num_merges": num_merges, "num_fps": num_fps, "num_fns": num_fns} print comatch_errors mg.import_matches(node_matches) output_dir = test_data_dir + "/MatchingManytoMany" mg.export_all(output_dir) with open(output_dir + "/macro_errors.json", "w+") as f: json.dump(comatch_errors, f)
def test_match():

    nodes_x = list(range(1, 8))
    nodes_y = list(range(101, 111))
    edges_xy = [
        (1, 101),
        (2, 102), (2, 103), (2, 104), (2, 109),
        (3, 103), (3, 108),
        (4, 104), (4, 109),
        (5, 105), (5, 110),
        (6, 106),
        (7, 107),
    ]
    node_labels_x = {n: 1 for n in nodes_x}
    node_labels_y = {n: 2 for n in nodes_y}
    node_labels_y[108] = 3
    node_labels_y[109] = 3
    node_labels_y[110] = 3

    edge_conflicts = [
        [(3, 103), (3, 108)],
        [(4, 104), (4, 109)],
        [(5, 105), (5, 110)],
        [(2, 109), (2, 104)],
        [(2, 109), (2, 103)],
        [(2, 102), (2, 109)],
    ]

    label_matches, node_matches, splits, merges, fps, fns = comatch.match_components(
        nodes_x, nodes_y, edges_xy, node_labels_x, node_labels_y,
        edge_conflicts=edge_conflicts)

    print(node_matches)
    print("splits: %d" % splits)
    print("merges: %d" % merges)
    print("fps   : %d" % fps)
    print("fns   : %d" % fns)

    # the other way around
    label_matches, node_matches, splits, merges, fps, fns = comatch.match_components(
        nodes_y, nodes_x,
        [(v, u) for (u, v) in edges_xy],
        node_labels_y, node_labels_x,
        edge_conflicts=[
            [(b, a) for (a, b) in conflict]
            for conflict in edge_conflicts
        ])

    print(node_matches)
    print("splits: %d" % splits)
    print("merges: %d" % merges)
    print("fps   : %d" % fps)
    print("fns   : %d" % fns)
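# A minimal post-processing sketch, separate from the test above. It assumes
# only what the tests demonstrate: `node_matches` returned by
# comatch.match_components is a list of (node_x, node_y) pairs. The helper
# name `node_match_luts` is illustrative and not part of comatch.
def node_match_luts(node_matches):
    # map every x-node to the set of matched y-nodes and vice versa
    x_lut, y_lut = {}, {}
    for x, y in node_matches:
        x_lut.setdefault(x, set()).add(y)
        y_lut.setdefault(y, set()).add(x)
    return x_lut, y_lut


# usage on a hand-written match list:
node_matches = [(1, 101), (2, 102), (2, 103)]
x_lut, y_lut = node_match_luts(node_matches)
assert x_lut[2] == {102, 103}
assert y_lut[103] == {2}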
def runTest(self): print "Import matches" mg = MatchingGraph(self.groundtruth_skeletons, self.reconstructed_skeletons, self.distance_threshold, self.voxel_size, verbose=True, initialize_all=True) nodes_gt, nodes_rec, edges_gt_rec, labels_gt, labels_rec, edge_costs, edge_conflicts, edge_pairs = mg.export_to_comatch() label_matches, node_matches, num_splits, num_merges, num_fps, num_fns = match_components(nodes_gt, nodes_rec, edges_gt_rec, labels_gt, labels_rec) matches = node_matches # Everything is matched self.assertTrue(len(matches) == mg.get_number_of_vertices()/2) mg.import_matches(matches) for v in mg.get_vertex_iterator(): self.assertTrue(mg.is_tp(v)) self.assertFalse(mg.is_fp(v)) self.assertFalse(mg.is_fn(v)) for e in mg.get_edge_iterator(): self.assertFalse(mg.is_split(e)) self.assertFalse(mg.is_merge(e))
def evaluate_matching_graph(matching_graph,
                            max_edges=1,
                            export_to=None,
                            optimality_gap=0.0,
                            time_limit=None,
                            n_gts=-1,
                            n_recs=-1):

    # Edge conflicts are only needed if a node may match more than one partner.
    if max_edges is None or max_edges > 1:
        use_edge_conflicts = True
    else:
        use_edge_conflicts = False

    (nodes_gt, nodes_rec, edges_gt_rec, labels_gt, labels_rec,
     edge_costs, edge_conflicts, edge_pairs) = matching_graph.export_to_comatch(
        edge_conflicts=use_edge_conflicts, edge_pairs=False)

    logger.info("Match using hungarian match...")
    (label_matches, node_matches,
     num_splits, num_merges, num_fps, num_fns) = match_components(
        nodes_gt, nodes_rec, edges_gt_rec, labels_gt, labels_rec,
        edge_conflicts=edge_conflicts,
        max_edges=max_edges,
        optimality_gap=optimality_gap,
        time_limit=time_limit)

    matching_graph.import_matches(node_matches)

    topological_errors = {
        "n_gt": n_gts,
        "n_rec": n_recs,
        "splits": num_splits,
        "merges": num_merges,
        "fps": num_fps,
        "fns": num_fns
    }
    node_errors = matching_graph.evaluate()

    if export_to is not None:
        matching_graph.export_all(export_to)
        with open(export_to + "/object_stats.txt", "w+") as f:
            json.dump(topological_errors, f)

    return matching_graph, topological_errors, node_errors
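# Hypothetical usage sketch for evaluate_matching_graph(), kept as comments
# because it needs real skeleton data. The MatchingGraph constructor arguments
# follow the test cases above; none of this is verified here.
#
# mg = MatchingGraph(gt_skeletons, rec_skeletons,
#                    distance_threshold, voxel_size,
#                    verbose=True, initialize_all=True)
# mg, topological_errors, node_errors = evaluate_matching_graph(
#     mg, max_edges=1, export_to="./evaluation")
# print(topological_errors["splits"], topological_errors["merges"])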
def score_graph(
    predicted_tracings: nx.Graph,
    reference_tracings: nx.Graph,
    match_threshold: float,
    location_attr: str,
    metric: Metric,
    **metric_kwargs,
):
    # Match the graphs:
    edges_xy = get_edges_xy(
        predicted_tracings, reference_tracings, location_attr, match_threshold
    )
    nodes_x = list(predicted_tracings.nodes)
    nodes_y = list(reference_tracings.nodes)
    node_labels_x = {
        node: cc
        for cc, cc_nodes in enumerate(nx.connected_components(predicted_tracings))
        for node in cc_nodes
    }
    node_labels_y = {
        node: cc
        for cc, cc_nodes in enumerate(nx.connected_components(reference_tracings))
        for node in cc_nodes
    }
    label_matches, node_matches, splits, merges, fps, fns = comatch.match_components(
        nodes_x, nodes_y, edges_xy, node_labels_x, node_labels_y
    )

    # evaluate the matching
    return evaluate_matching(
        metric,
        node_matches,
        node_labels_x,
        node_labels_y,
        predicted_tracings,
        reference_tracings,
        location_attr,
        **metric_kwargs,
    )
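# The helper get_edges_xy used above is not shown in this file. Purely as an
# illustration (and not necessarily how the real helper is implemented),
# candidate match edges between two spatial graphs can be generated with a
# KD-tree query over node locations, keeping all pairs within
# `match_threshold`. The name candidate_edges_xy is made up for this sketch.
import networkx as nx
import numpy as np
from scipy.spatial import cKDTree


def candidate_edges_xy(graph_x: nx.Graph, graph_y: nx.Graph,
                       location_attr: str, match_threshold: float):
    nodes_x = list(graph_x.nodes)
    nodes_y = list(graph_y.nodes)
    if not nodes_x or not nodes_y:
        return []
    locs_x = np.array([graph_x.nodes[n][location_attr] for n in nodes_x])
    locs_y = np.array([graph_y.nodes[n][location_attr] for n in nodes_y])
    tree_x = cKDTree(locs_x)
    tree_y = cKDTree(locs_y)
    # for every x-node, indices of all y-nodes within match_threshold
    neighbors = tree_x.query_ball_tree(tree_y, r=match_threshold)
    return [(nodes_x[i], nodes_y[j])
            for i, js in enumerate(neighbors) for j in js]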
def add_mst_snapshot_with_stats(
    context,
    matching_data,
    mst_snapshot,
    gt,
    threshold_inds=None,
    graph_node_attrs=None,
    graph_edge_attrs=None,
    false_pos_threshold=None,
):
    if false_pos_threshold is None:
        false_pos_threshold = 0

    mst = tree_from_snapshot(
        mst_snapshot,
        "mst",
        graph_node_attrs=graph_node_attrs,
        graph_edge_attrs=graph_edge_attrs,
    )
    label_matchings, node_matchings, node_labels_mst, node_labels_gt, thresholds = (
        matching_data
    )

    edges_to_add = list(mst.edges.items())
    thresholded_graph = nx.Graph()
    for i, threshold in enumerate(thresholds):
        # add all mst edges below the current threshold
        for j, ((u, v), attrs) in reversed(list(enumerate(edges_to_add))):
            if attrs["distance"] < threshold:
                thresholded_graph.add_node(u, **mst.nodes[u])
                thresholded_graph.add_node(v, **mst.nodes[v])
                thresholded_graph.add_edge(u, v, **attrs)
                del edges_to_add[j]

        if i in threshold_inds:
            temp = copy.deepcopy(thresholded_graph)

            # remove small connected components (treated as false positives)
            false_pos_nodes = []
            for cc in nx.connected_components(temp):
                cc_graph = temp.subgraph(cc)
                min_loc = None
                max_loc = None
                for node, attrs in cc_graph.nodes.items():
                    node_loc = attrs["location"]
                    if min_loc is None:
                        min_loc = node_loc
                    else:
                        min_loc = np.min(np.array([node_loc, min_loc]), axis=0)
                    if max_loc is None:
                        max_loc = node_loc
                    else:
                        max_loc = np.max(np.array([node_loc, max_loc]), axis=0)
                if np.linalg.norm(min_loc - max_loc) < false_pos_threshold:
                    false_pos_nodes += list(cc)
            for node in false_pos_nodes:
                temp.remove_node(node)

            nodes_x = list(temp.nodes)
            nodes_y = list(gt.nodes)
            node_labels_x = {
                node: component
                for component, nodes in enumerate(nx.connected_components(temp))
                for node in nodes
            }
            node_labels_y = {
                node: component
                for component, nodes in enumerate(nx.connected_components(gt))
                for node in nodes
            }

            edges_yx = get_edges_xy(
                gt, temp, location_attr="location", node_match_threshold=4000
            )
            edges_xy = [(v, u) for u, v in edges_yx]
            (label_matches, node_matches, splits, merges, fps, fns) = match_components(
                nodes_x, nodes_y, edges_xy, node_labels_x, node_labels_y
            )
            erl, details = psudo_graph_edit_distance(
                node_matches,
                node_labels_x,
                node_labels_y,
                temp,
                gt,
                "location",
                node_spacing=5000,
                details=True,
            )
            add_match_layers(
                context,
                temp,
                gt,
                label_matchings[i],
                node_matchings[i],
                node_labels_mst[i],
                node_labels_gt[i],
                name=f"Matched-{threshold:.3f}",
            )
    (3, 108),
    (4, 104), (4, 109),
    (5, 105), (5, 110),
    (6, 106),
    (7, 107),
]
node_labels_x = {n: 1 for n in nodes_x}
node_labels_y = {n: 2 for n in nodes_y}
node_labels_y[108] = 3
node_labels_y[109] = 3
node_labels_y[110] = 3

label_matches, node_matches, splits, merges, fps, fns = comatch.match_components(
    nodes_x, nodes_y, edges_xy, node_labels_x, node_labels_y)

print(node_matches)
print("splits: %d" % splits)
print("merges: %d" % merges)
print("fps   : %d" % fps)
print("fns   : %d" % fns)

# the other way around
label_matches, node_matches, splits, merges, fps, fns = comatch.match_components(
    nodes_y, nodes_x,
    [(v, u) for (u, v) in edges_xy],
    node_labels_y, node_labels_x)

print(node_matches)
def process(self, batch, request):
    num_thresholds = self.num_thresholds
    threshold_range = self.threshold_range

    outputs = gp.Batch()

    gt_graph = batch[self.gt].to_nx_graph().to_undirected()
    mst_graph = batch[self.mst].to_nx_graph().to_undirected()
    if self.connectivity is not None:
        connectivity_graph = batch[self.connectivity].to_nx_graph().to_undirected()

    # assert mst_graph.number_of_nodes() > 0, f"mst_graph is empty!"

    if self.details is not None:
        matching_details_graph = nx.Graph()
        if mst_graph.number_of_nodes() == 0:
            logger.warning("mst_graph is empty!")
        # gt node ids are shifted past the largest mst node id; the "+ [-1]"
        # keeps the offset defined for an empty mst_graph
        node_offset = max([node for node in mst_graph.nodes] + [-1]) + 1
        label_offset = len(list(nx.connected_components(mst_graph))) + 1
        for node, attrs in mst_graph.nodes.items():
            matching_details_graph.add_node(node, **copy.deepcopy(attrs))
        for edge, attrs in mst_graph.edges.items():
            matching_details_graph.add_edge(edge[0], edge[1], **copy.deepcopy(attrs))
        for node, attrs in gt_graph.nodes.items():
            matching_details_graph.add_node(node + node_offset, **copy.deepcopy(attrs))
            matching_details_graph.nodes[node + node_offset]["id"] = node + node_offset
        for edge, attrs in gt_graph.edges.items():
            matching_details_graph.add_edge(
                edge[0] + node_offset, edge[1] + node_offset, **copy.deepcopy(attrs))

    edges = [(edge, attrs[self.edge_threshold_attr])
             for edge, attrs in mst_graph.edges.items()]
    edges = list(sorted(edges, key=lambda x: x[1]))
    edge_lens = [e[1] for e in edges]

    # min_threshold = edges[0][1]
    if len(edge_lens) > 0:
        min_threshold = edge_lens[int(len(edge_lens) * threshold_range[0])]
        max_threshold = edge_lens[int(len(edge_lens) * threshold_range[1]) - 1]
    else:
        min_threshold = 0
        max_threshold = 1
    thresholds = np.linspace(min_threshold, max_threshold, num=num_thresholds)

    current_threshold_mst = nx.Graph()
    edge_deque = deque(edges)

    edit_distances = []
    split_costs = []
    merge_costs = []
    false_pos_costs = []
    false_neg_costs = []
    num_nodes = []
    num_edges = []

    best_score = None
    best_graph = None

    for threshold_index, threshold in enumerate(thresholds):
        logger.warning(f"Using threshold: {threshold}")

        # grow the thresholded mst by all edges up to the current threshold
        while len(edge_deque) > 0 and edge_deque[0][1] <= threshold:
            (u, v), _ = edge_deque.popleft()
            attrs = mst_graph.edges[(u, v)]
            current_threshold_mst.add_edge(u, v)
            current_threshold_mst.add_node(u, **mst_graph.nodes[u])
            current_threshold_mst.add_node(v, **mst_graph.nodes[v])

        if self.connectivity is not None:
            # rebuild the thresholded graph on top of the connectivity graph,
            # keeping only edges within a single selected component
            temp = nx.Graph()
            next_node = max([node for node in connectivity_graph.nodes]) + 1
            for i, cc in enumerate(nx.connected_components(current_threshold_mst)):
                component_subgraph = current_threshold_mst.subgraph(cc)
                for node in component_subgraph.nodes:
                    temp.add_node(node, **dict(connectivity_graph.nodes[node]))
                    temp.nodes[node]["component"] = i
            for edge in connectivity_graph.edges:
                if (edge[0] in temp.nodes and edge[1] in temp.nodes
                        and temp.nodes[edge[0]]["component"]
                        == temp.nodes[edge[1]]["component"]):
                    temp.add_edge(edge[0], edge[1],
                                  **dict(connectivity_graph.edges[edge]))
                elif False:
                    # disabled: clone shortest paths through the connectivity
                    # graph between the edge endpoints
                    path = nx.shortest_path(connectivity_graph, edge[0], edge[1])
                    cloned_path = []
                    for node in path:
                        if node in temp.nodes:
                            cloned_path.append(node)
                        else:
                            cloned_path.append(next_node)
                            next_node += 1
                    path_len = len(cloned_path) - 1
                    for i, j in zip(range(path_len), range(1, path_len + 1)):
                        u = cloned_path[i]
                        if u not in temp.nodes:
                            temp.add_node(u, **dict(connectivity_graph.nodes[path[i]]))
                        v = cloned_path[j]
                        if v not in temp.nodes:
                            temp.add_node(v, **dict(connectivity_graph.nodes[path[j]]))
                        temp.add_edge(
                            u, v,
                            **dict(connectivity_graph.edges[path[i], path[j]]),
                        )
        else:
            temp = copy.deepcopy(current_threshold_mst)
            for i, cc in enumerate(nx.connected_components(temp)):
                for node in cc:
                    attrs = temp.nodes[node]
                    attrs["component"] = i

        # remove small connected_components
        false_pos_nodes = []
        for cc in nx.connected_components(temp):
            cc_graph = temp.subgraph(cc)
            min_loc = None
            max_loc = None
            for node, attrs in cc_graph.nodes.items():
                node_loc = attrs[self.location_attr]
                if min_loc is None:
                    min_loc = node_loc
                else:
                    min_loc = np.min(np.array([node_loc, min_loc]), axis=0)
                if max_loc is None:
                    max_loc = node_loc
                else:
                    max_loc = np.max(np.array([node_loc, max_loc]), axis=0)
            if np.linalg.norm(min_loc - max_loc) < self.small_component_threshold:
                false_pos_nodes += list(cc)
        for node in false_pos_nodes:
            temp.remove_node(node)

        nodes_x = list(temp.nodes)
        nodes_y = list(gt_graph.nodes)
        node_labels_x = {
            node: attrs["component"] for node, attrs in temp.nodes.items()
        }
        node_labels_y = {
            node: component
            for component, nodes in enumerate(nx.connected_components(gt_graph))
            for node in nodes
        }

        edges_yx = get_edges_xy(
            gt_graph,
            temp,
            location_attr=self.location_attr,
            node_match_threshold=self.comatch_threshold,
        )
        edges_xy = [(v, u) for u, v in edges_yx]
        result = match_components(
            nodes_x,
            nodes_y,
            edges_xy,
            copy.deepcopy(node_labels_x),
            copy.deepcopy(node_labels_y),
        )

        if self.details is not None:
            # add a match details graph to the batch.
            # details is a graph containing nodes from both mst and gt.
            # to access the details of a node, use
            # `details.nodes[node]["details"]`, a boolean numpy array of shape
            # (num_thresholds, 7) whose columns per threshold are, in order:
            # selected, success, fp, fn, merge, split, gt
            # NODES
            # success, mst: n matches to only nodes of 1 label, which matches its own label
            # success, gt: n matches to only nodes of 1 label, which matches its own label
            # fp: n in mst matches to nothing
            # fn: n in gt matches to nothing
            # merge: n in mst matches to a node with label not matching its own
            # split: n in gt matches to a node with label not matching its own
            # selected: n in mst in thresholded graph
            # gt: n stems from the gt graph (rather than from the mst)
            # EDGES
            # success, mst: both endpoints successful
            # success, gt: both endpoints successful
            # fp: both endpoints fp
            # fn: both endpoints fn
            # merge: e in mst: only one endpoint successful
            # split: e in gt: only one endpoint successful
            # selected: e in mst in thresholded graph

            (label_matches, node_matches, splits, merges, fps, fns) = result

            # create lookup tables:
            x_label_match_lut = {}
            y_label_match_lut = {}
            for a, b in label_matches:
                x_matches = x_label_match_lut.setdefault(a, set())
                x_matches.add(b)
                y_matches = y_label_match_lut.setdefault(b, set())
                y_matches.add(a)
            x_node_match_lut = {}
            y_node_match_lut = {}
            for a, b in node_matches:
                x_matches = x_node_match_lut.setdefault(a, set())
                x_matches.add(b)
                y_matches = y_node_match_lut.setdefault(b, set())
                y_matches.add(a)

            for node, attrs in matching_details_graph.nodes.items():
                gt = int(node >= node_offset)
                mst = 1 - gt
                if gt == 1:
                    node = node - node_offset
                selected = gt or (node in temp.nodes())
                if selected:
                    success, fp, fn, merge, split, label_pair = self.node_matching_result(
                        node,
                        gt,
                        x_label_match_lut,
                        y_label_match_lut,
                        x_node_match_lut,
                        y_node_match_lut,
                        node_labels_x,
                        node_labels_y,
                    )
                else:
                    success, fp, fn, merge, split, label_pair = (
                        0,
                        0,
                        0,
                        0,
                        0,
                        (-1, -1),
                    )
                data = attrs.setdefault(
                    "details", np.zeros((len(thresholds), 7), dtype=bool))
                data[threshold_index] = [
                    selected,
                    success,
                    fp,
                    fn,
                    merge,
                    split,
                    gt,
                ]
                label_pairs = attrs.setdefault("label_pair", [])
                label_pairs.append(label_pair)
                assert len(label_pairs) == threshold_index + 1

            for (u, v), attrs in matching_details_graph.edges.items():
                (
                    u_selected,
                    u_success,
                    u_fp,
                    u_fn,
                    u_merge,
                    u_split,
                    u_gt,
                ) = matching_details_graph.nodes[u]["details"][threshold_index]
                (
                    v_selected,
                    v_success,
                    v_fp,
                    v_fn,
                    v_merge,
                    v_split,
                    v_gt,
                ) = matching_details_graph.nodes[v]["details"][threshold_index]
                assert u_gt == v_gt
                e_gt = u_gt
                u_label_pair = matching_details_graph.nodes[u]["label_pair"][threshold_index]
                v_label_pair = matching_details_graph.nodes[v]["label_pair"][threshold_index]

                e_selected = u_selected and v_selected
                e_success = (e_selected and u_success and v_success
                             and (u_label_pair == v_label_pair))
                e_fp = u_fp and v_fp
                e_fn = u_fn and v_fn
                e_merge = e_selected and (not e_success) and (not e_fp) and not e_gt
                e_split = e_selected and (not e_success) and (not e_fn) and e_gt
                assert not (e_success and e_merge)
                assert not (e_success and e_split)

                data = attrs.setdefault(
                    "details", np.zeros((len(thresholds), 7), dtype=bool))
                if e_success:
                    label_pairs = attrs.setdefault("label_pair", [])
                    label_pairs.append(u_label_pair)
                    assert len(label_pairs) == threshold_index + 1
                else:
                    label_pairs = attrs.setdefault("label_pair", [])
                    label_pairs.append((-1, -1))
                    assert len(label_pairs) == threshold_index + 1
                data[threshold_index] = [
                    e_selected,
                    e_success,
                    e_fp,
                    e_fn,
                    e_merge,
                    e_split,
                    e_gt,
                ]

        edit_distance, (
            split_cost,
            merge_cost,
            false_pos_cost,
            false_neg_cost,
        ) = psudo_graph_edit_distance(
            result[1],
            node_labels_x,
            node_labels_y,
            temp,
            gt_graph,
            self.location_attr,
            node_spacing=self.edit_distance_node_spacing,
            details=True,
        )

        edit_distances.append(edit_distance)
        split_costs.append(split_cost)
        merge_costs.append(merge_cost)
        false_pos_costs.append(false_pos_cost)
        false_neg_costs.append(false_neg_cost)
        num_nodes.append(len(temp.nodes))
        num_edges.append(len(temp.edges))

        # save the best version:
        if best_score is None:
            best_score = edit_distance
            best_graph = copy.deepcopy(temp)
        elif edit_distance < best_score:
            best_score = edit_distance
            best_graph = copy.deepcopy(temp)

    outputs[self.output] = gp.Array(
        np.array([
            edit_distances,
            thresholds,
            num_nodes,
            num_edges,
            split_costs,
            merge_costs,
            false_pos_costs,
            false_neg_costs,
        ]),
        gp.ArraySpec(nonspatial=True),
    )
    if self.output_graph is not None:
        outputs[self.output_graph] = gp.Graph.from_nx_graph(
            best_graph,
            gp.GraphSpec(roi=batch[self.gt].spec.roi, directed=False))
    if self.details is not None:
        outputs[self.details] = gp.Graph.from_nx_graph(
            matching_details_graph,
            gp.GraphSpec(roi=batch[self.gt].spec.roi, directed=False),
        )
    return outputs
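# A small sketch of how the per-node "details" arrays written by process()
# above could be summarised after a pipeline run. `details_graph` stands for
# the networkx graph delivered in outputs[self.details]; the column order
# [selected, success, fp, fn, merge, split, gt] mirrors what is stored above.
# The function name summarise_details is made up for this sketch.
def summarise_details(details_graph, threshold_index):
    columns = ["selected", "success", "fp", "fn", "merge", "split", "gt"]
    counts = {name: 0 for name in columns}
    for _, attrs in details_graph.nodes.items():
        row = attrs["details"][threshold_index]
        for name, value in zip(columns, row):
            counts[name] += int(value)
    return counts


# usage:
# counts = summarise_details(details_graph, threshold_index=0)
# print(counts["fp"], counts["fn"])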
def evaluate(matching_graph,
             max_edges=1,
             optimality_gap=0.0,
             time_limit=None,
             n_gts=-1,
             n_recs=-1):

    (nodes_gt, nodes_rec, labels_gt, labels_rec,
     edges_gt_rec, edge_conflicts) = matching_graph.export()

    if nodes_gt and nodes_rec:
        (label_matches, node_matches,
         num_splits, num_merges, num_fps, num_fns) = match_components(
            nodes_gt, nodes_rec, edges_gt_rec, labels_gt, labels_rec,
            edge_conflicts=edge_conflicts,
            max_edges=max_edges,
            optimality_gap=optimality_gap,
            time_limit=time_limit)

        topological_errors = {
            "n_gt": len(set(labels_gt.values())),
            "n_rec": len(set(labels_rec.values())),
            "splits": num_splits,
            "merges": num_merges,
            "fps": num_fps,
            "fns": num_fns
        }

        matching_graph.import_node_matches(node_matches)
        node_errors = matching_graph.get_stats()

    else:
        # degenerate cases: one of the two graphs is empty
        if not nodes_gt:
            fps = matching_graph.rec_ccs
            fns = 0
            nodes_fps = len(nodes_rec)
            nodes_fns = 0
        if not nodes_rec:
            fps = 0
            fns = matching_graph.gt_ccs
            nodes_fps = 0
            nodes_fns = len(nodes_gt)

        topological_errors = {
            "n_gt": matching_graph.gt_ccs,
            "n_rec": matching_graph.rec_ccs,
            "splits": 0,
            "merges": 0,
            "fps": fps,
            "fns": fns
        }
        node_errors = {
            "vertices": 0,
            "edges": 0,
            "tps_rec": 0,
            "tps_gt": 0,
            "fps": nodes_fps,
            "fns": nodes_fns,
            "merges": 0,
            "splits": 0
        }

    return node_errors, topological_errors
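# A minimal aggregation sketch, not part of evaluate() itself: the
# topological_errors dicts returned above share the keys
# n_gt / n_rec / splits / merges / fps / fns, so results from several
# evaluated blocks can be reduced by summing per key. The helper name
# combine_topological_errors is made up for this sketch.
def combine_topological_errors(error_dicts):
    keys = ["n_gt", "n_rec", "splits", "merges", "fps", "fns"]
    return {k: sum(d[k] for d in error_dicts) for k in keys}


# usage:
# total = combine_topological_errors([errors_block_1, errors_block_2])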