def get_gt_tracks_for_roi(gt_db_name, mongo_url, roi):
    """Collect one linear cell path per cell found in the last frame of ``roi``.

    The graph stored in database ``gt_db_name`` (reached via ``mongo_url``)
    is read for ``roi`` and split into tracks.  For every cell in the last
    frame of the roi, the chain of ``prev_edges`` is followed until a cell
    without a previous edge is reached.

    Args:
        gt_db_name: Name of the database passed to
            ``linajea.CandidateDatabase``.
        mongo_url: Host/URL passed to ``linajea.CandidateDatabase``.
        roi: Region to read; must provide ``get_offset()`` and
            ``get_shape()`` whose first entry is the time axis.

    Returns:
        A list of paths, each a list of ``[t, z, y, x]`` entries ordered
        from the end cell backwards, or ``None`` if any visited cell has
        more than one previous edge.
    """
    graph_provider = linajea.CandidateDatabase(gt_db_name, mongo_url)
    subgraph = graph_provider[roi]
    # NOTE(review): no frame_key is passed here, so TrackGraph's default is
    # used -- confirm it matches the 't' attribute read below.
    track_graph = linajea.tracking.TrackGraph(subgraph)
    tracks = track_graph.get_tracks()
    # Index of the last frame contained in the roi.
    end_frame = roi.get_offset()[0] + roi.get_shape()[0] - 1

    one_d_tracks = []
    for track in tracks:
        for end_cell in track.get_cells_in_frame(end_frame):
            cell_positions = []
            current_cell = end_cell
            while current_cell is not None:
                current_data = track.nodes[current_cell]
                cell_positions.append(
                    [current_data[dim] for dim in ['t', 'z', 'y', 'x']])
                # Materialise the edges: an edge view supports len() but
                # not positional indexing, which is needed just below.
                parent_edges = list(track.prev_edges(current_cell))
                if len(parent_edges) == 1:
                    current_cell = parent_edges[0][1]
                elif len(parent_edges) == 0:
                    current_cell = None
                else:
                    print("Error: Cell has two parents! Exiting")
                    return None
            one_d_tracks.append(cell_positions)

    print("Found %d tracks in roi %s" % (len(one_d_tracks), roi))
    return one_d_tracks
def create_graph(self, cells, edges, roi):
    """Return a ``TrackGraph`` over ``roi`` filled with ``cells`` and ``edges``.

    The underlying graph is obtained from the ``'test_eval'`` candidate
    database on ``'localhost'``; ``cells`` and ``edges`` are handed to
    ``add_nodes_from`` / ``add_edges_from`` unchanged.
    """
    provider = linajea.CandidateDatabase('test_eval', 'localhost')
    roi_graph = provider[roi]
    roi_graph.add_nodes_from(cells)
    roi_graph.add_edges_from(edges)
    return linajea.tracking.TrackGraph(graph_data=roi_graph, frame_key='t')
def get_track_from_node(
        node_id,
        node_frame,
        gt_db_name,
        mongo_url,
        num_frames_before,
        num_frames_after=0):
    """Trace the path leading back from ``node_id`` inside a time window.

    The window spans ``num_frames_before + num_frames_after`` frames and
    starts at ``node_frame - num_frames_before + 1``.  Starting at
    ``node_id`` the chain of ``prev_edges`` is followed until a cell has no
    previous edge or carries no ``'t'`` attribute.

    Returns:
        A list of ``[t, z, y, x]`` entries starting at ``node_id``, or
        ``None`` when no track in the window contains ``node_id`` or a
        visited cell has more than one previous edge.
    """
    provider = linajea.CandidateDatabase(gt_db_name, mongo_url)
    roi_offset = (node_frame - num_frames_before + 1, 0, 0, 0)
    roi_shape = (num_frames_before + num_frames_after, 10e6, 10e6, 10e6)
    roi = daisy.Roi(roi_offset, roi_shape)
    track_graph = linajea.tracking.TrackGraph(provider[roi], frame_key='t')

    for track in track_graph.get_tracks():
        if not track.has_node(node_id):
            continue

        positions = []
        cell = node_id
        while True:
            attrs = track.nodes[cell]
            if 't' not in attrs:
                # Cell has no frame attribute: stop the walk here.
                break
            positions.append([attrs[dim] for dim in ['t', 'z', 'y', 'x']])
            parents = list(track.prev_edges(cell))
            if len(parents) > 1:
                print("Error: Cell has two parents! Exiting")
                return None
            if len(parents) == 0:
                break
            cell = parents[0][1]
        return positions

    print("Did not find track with node %d in roi %s" % (node_id, roi))
    return None
def test_solver_node_close_to_edge(self):
    # Checks NMSolver._check_node_close_to_roi_edge with a distance of 1:
    # nodes 2 and 4 must be reported as close to the roi edge, every other
    # node must not.
    #
    #   x
    #  3|         /-4
    #  2|        /--3
    #  1|   0---1
    #  0|        \--2
    #    ------------------------------------ t
    #       0   1   2
    cells = [
        {'id': 0, 't': 0, 'z': 1, 'y': 1, 'x': 1, 'score': 2.0},
        {'id': 1, 't': 1, 'z': 1, 'y': 1, 'x': 1, 'score': 2.0},
        {'id': 2, 't': 2, 'z': 1, 'y': 1, 'x': 0, 'score': 2.0},
        {'id': 3, 't': 2, 'z': 1, 'y': 1, 'x': 2, 'score': 2.0},
        {'id': 4, 't': 2, 'z': 1, 'y': 1, 'x': 4, 'score': 2.0}
    ]
    edges = [
        {'source': 1, 'target': 0, 'score': 1.0,
         'prediction_distance': 0.0},
        {'source': 2, 'target': 1, 'score': 1.0,
         'prediction_distance': 1.0},
        {'source': 3, 'target': 1, 'score': 1.0,
         'prediction_distance': 1.0},
        {'source': 4, 'target': 1, 'score': 1.0,
         'prediction_distance': 2.0},
    ]
    db_name = 'linajea_test_solver'
    db_host = 'localhost'
    graph_provider = linajea.CandidateDatabase(
        db_name,
        db_host)
    # Drop the test database even when something below fails, so a failing
    # run cannot leak state into later runs that reuse this database name.
    try:
        roi = daisy.Roi((0, 0, 0, 0), (5, 5, 5, 5))
        graph = graph_provider[roi]
        ps = {
            "cost_appear": 1.0,
            "cost_disappear": 1.0,
            "cost_split": 0,
            "weight_prediction_distance_cost": 0.1,
            "weight_node_score": 1.0,
            "threshold_node_score": 0.0,
            "threshold_edge_score": 0.0,
            "max_cell_move": 1.0,
            "block_size": [5, 100, 100, 100],
            "context": [2, 100, 100, 100],
        }
        parameters = linajea.tracking.NMTrackingParameters(**ps)

        graph.add_nodes_from([(cell['id'], cell) for cell in cells])
        graph.add_edges_from(
            [(edge['source'], edge['target'], edge) for edge in edges])
        track_graph = linajea.tracking.TrackGraph(
            graph, frame_key='t', roi=graph.roi)
        solver = linajea.tracking.NMSolver(
            track_graph, parameters, 'selected')

        for node, data in track_graph.nodes(data=True):
            close = solver._check_node_close_to_roi_edge(node, data, 1)
            # Nodes 2 and 4 are the ones expected to be close; invert
            # their result so one assertFalse covers both cases.
            if node in [2, 4]:
                close = not close
            self.assertFalse(close)
    finally:
        self.delete_db(db_name, db_host)
def test_greedy_node_threshold(self):
    # Runs greedy_track with node_threshold=1.5.  Node 3 is the only cell
    # whose score (1.0) is below that threshold, and the expected
    # selection contains no edge touching it.
    #
    #   x
    #  3|         /-4 \
    #  2|        /--3---5
    #  1|   0---1
    #  0|        \--2
    #    ------------------------------------ t
    #       0   1   2   3
    cells = [
        {'id': 0, 't': 0, 'z': 1, 'y': 1, 'x': 1, 'score': 2.0},
        {'id': 1, 't': 1, 'z': 1, 'y': 1, 'x': 1, 'score': 2.0},
        {'id': 2, 't': 2, 'z': 1, 'y': 1, 'x': 0, 'score': 2.0},
        {'id': 3, 't': 2, 'z': 1, 'y': 1, 'x': 2, 'score': 1.0},
        {'id': 4, 't': 2, 'z': 1, 'y': 1, 'x': 3, 'score': 2.0},
        {'id': 5, 't': 3, 'z': 1, 'y': 1, 'x': 2, 'score': 2.0}
    ]
    edges = [
        {'source': 1, 'target': 0, 'score': 1.0, 'distance': 0.0},
        {'source': 2, 'target': 1, 'score': 1.0, 'distance': 1.0},
        {'source': 3, 'target': 1, 'score': 1.0, 'distance': 1.0},
        {'source': 4, 'target': 1, 'score': 1.0, 'distance': 2.0},
        {'source': 5, 'target': 3, 'score': 1.0, 'distance': 0.0},
        {'source': 5, 'target': 4, 'score': 1.0, 'distance': 1.0},
    ]
    db_name = 'linajea_test_solver'
    db_host = 'localhost'
    graph_provider = linajea.CandidateDatabase(
        db_name,
        db_host)
    # Drop the test database even when something below fails, so a failing
    # run cannot leak state into later runs that reuse this database name.
    try:
        roi = daisy.Roi((0, 0, 0, 0), (4, 5, 5, 5))
        graph = graph_provider[roi]
        graph.add_nodes_from([(cell['id'], cell) for cell in cells])
        graph.add_edges_from(
            [(edge['source'], edge['target'], edge) for edge in edges])
        linajea.tracking.greedy_track(
            graph,
            selected_key='selected',
            metric='distance',
            frame_key='t',
            node_threshold=1.5)

        selected_edges = [
            (u, v)
            for u, v, data in graph.edges(data=True)
            if data['selected']
        ]
        expected_result = [
            (1, 0),
            (2, 1),
            (4, 1),
            (5, 4)
        ]
        self.assertCountEqual(selected_edges, expected_result)
    finally:
        self.delete_db(db_name, db_host)
def read_nodes_and_edges(
        self,
        db_name,
        frames=None,
        nodes_key=None,
        edges_key=None,
        key=None,
        filter_unattached=True):
    """Read nodes and edges from ``db_name`` and shift ids to start at 0.

    Args:
        db_name: Database name passed to ``linajea.CandidateDatabase``
            together with ``self.mongo_url``.
        frames: ``[begin, end]`` pair along the first roi axis; defaults
            to ``[0, 1e10]``.
        nodes_key: If given, only nodes whose ``nodes_key`` attribute is
            ``True`` are read.
        edges_key: If given, only edges whose ``edges_key`` attribute is
            ``True`` are read.
        key: Fallback for ``edges_key`` when that is ``None``.
        filter_unattached: If true, drop nodes that are neither source nor
            target of any edge that was read.

    Returns:
        ``(nodes, edges)``.  Each node keeps its original id under
        ``'name'`` and gets ``'id'`` reduced by the smallest id present;
        each edge's ``'source'`` and ``'target'`` are reduced by the same
        amount.  When no nodes remain, both are returned unmodified
        instead of failing on ``min()`` of an empty sequence.
    """
    db = linajea.CandidateDatabase(db_name, self.mongo_url)
    if frames is None:
        frames = [0, 1e10]
    roi = daisy.Roi((frames[0], 0, 0, 0),
                    (frames[1] - frames[0], 1e10, 1e10, 1e10))

    if nodes_key is None:
        nodes = db.read_nodes(roi)
    else:
        nodes = db.read_nodes(roi, attr_filter={nodes_key: True})
    node_ids = [node['id'] for node in nodes]
    logger.debug("Found %d nodes", len(node_ids))

    if edges_key is None and key is not None:
        edges_key = key
    if edges_key is None:
        edges = db.read_edges(roi, nodes=nodes)
    else:
        edges = db.read_edges(
            roi, nodes=nodes, attr_filter={edges_key: True})

    if filter_unattached:
        logger.debug("Filtering cells")
        filtered_cell_ids = {
            edge[end] for edge in edges for end in ('source', 'target')
        }
        nodes = [cell for cell in nodes if cell['id'] in filtered_cell_ids]
        node_ids = filtered_cell_ids
        logger.debug("Done filtering cells")

    if not node_ids:
        # Nothing to renumber; min() below would raise ValueError.
        logger.debug("No cells found, skipping id adjustment")
        return nodes, edges

    logger.debug("Adjusting ids")
    target_min_id = 0
    actual_min_id = min(node_ids)
    diff = actual_min_id - target_min_id
    logger.debug("Subtracting %s from all cell ids", diff)
    for node in nodes:
        # Keep the original id available under 'name'.
        node['name'] = node['id']
        node['id'] -= diff
    for edge in edges:
        edge['source'] -= diff
        edge['target'] -= diff

    return nodes, edges
def test_solver_cell_cycle2(self):
    '''x
      3|         /-4
      2|        /--3---5
      1|   0---1
      0|        \\--2
        ------------------------------------ t
           0   1   2   3

    Should select 0, 1, 3, 5 due to vgg predicting continuation
    '''
    cells = [
        {'id': 0, 't': 0, 'z': 1, 'y': 1, 'x': 1, 'score': 2.0,
         'vgg_score': [0, 0, 1]},
        {'id': 1, 't': 1, 'z': 1, 'y': 1, 'x': 1, 'score': 2.0,
         'vgg_score': [0, 0, 1]},
        {'id': 2, 't': 2, 'z': 1, 'y': 1, 'x': 0, 'score': 2.0,
         'vgg_score': [0, 0, 1]},
        {'id': 3, 't': 2, 'z': 1, 'y': 1, 'x': 2, 'score': 2.0,
         'vgg_score': [0, 0, 1]},
        {'id': 4, 't': 2, 'z': 1, 'y': 1, 'x': 3, 'score': 2.0,
         'vgg_score': [0, 0, 1]},
        {'id': 5, 't': 3, 'z': 1, 'y': 1, 'x': 2, 'score': 2.0,
         'vgg_score': [0, 0, 1]},
    ]
    edges = [
        {'source': 1, 'target': 0, 'score': 1.0,
         'prediction_distance': 0.0},
        {'source': 2, 'target': 1, 'score': 1.0,
         'prediction_distance': 1.0},
        {'source': 3, 'target': 1, 'score': 1.0,
         'prediction_distance': 1.0},
        {'source': 4, 'target': 1, 'score': 1.0,
         'prediction_distance': 2.0},
        {'source': 5, 'target': 3, 'score': 1.0,
         'prediction_distance': 0.0},
    ]
    db_name = 'linajea_test_solver'
    db_host = 'localhost'
    graph_provider = linajea.CandidateDatabase(db_name, db_host)
    # Drop the test database even when something below fails, so a failing
    # run cannot leak state into later runs that reuse this database name.
    try:
        roi = daisy.Roi((0, 0, 0, 0), (4, 5, 5, 5))
        graph = graph_provider[roi]
        ps = {
            "track_cost": 4.0,
            "weight_edge_score": 0.1,
            "weight_node_score": -0.1,
            "selection_constant": 0.0,
            "weight_division": -0.1,
            "weight_child": -0.1,
            "weight_continuation": -0.1,
            "division_constant": 1,
            "max_cell_move": 0.0,
            "block_size": [5, 100, 100, 100],
            "context": [2, 100, 100, 100],
        }
        parameters = linajea.tracking.TrackingParameters(**ps)

        graph.add_nodes_from([(cell['id'], cell) for cell in cells])
        graph.add_edges_from(
            [(edge['source'], edge['target'], edge) for edge in edges])
        linajea.tracking.track(
            graph,
            parameters,
            frame_key='t',
            selected_key='selected',
            cell_cycle_key="vgg_score")

        selected_edges = [
            (u, v)
            for u, v, data in graph.edges(data=True)
            if data['selected']
        ]
        expected_result = [(1, 0), (3, 1), (5, 3)]
        self.assertCountEqual(selected_edges, expected_result)
    finally:
        self.delete_db(db_name, db_host)
def test_solver_multiple_configs(self):
    # Solves the same graph with two parameter sets in one call and checks
    # the edges selected under each of the two selected keys.
    #
    #   x
    #  3|         /-4
    #  2|        /--3---5
    #  1|   0---1
    #  0|        \--2
    #    ------------------------------------ t
    #       0   1   2   3
    cells = [
        {'id': 0, 't': 0, 'z': 1, 'y': 1, 'x': 1, 'score': 2.0},
        {'id': 1, 't': 1, 'z': 1, 'y': 1, 'x': 1, 'score': 2.0},
        {'id': 2, 't': 2, 'z': 1, 'y': 1, 'x': 0, 'score': 2.0},
        {'id': 3, 't': 2, 'z': 1, 'y': 1, 'x': 2, 'score': 2.0},
        {'id': 4, 't': 2, 'z': 1, 'y': 1, 'x': 3, 'score': 2.0},
        {'id': 5, 't': 3, 'z': 1, 'y': 1, 'x': 2, 'score': 2.0},
    ]
    edges = [
        {'source': 1, 'target': 0, 'score': 1.0,
         'prediction_distance': 0.0},
        {'source': 2, 'target': 1, 'score': 1.0,
         'prediction_distance': 1.0},
        {'source': 3, 'target': 1, 'score': 1.0,
         'prediction_distance': 1.0},
        {'source': 4, 'target': 1, 'score': 1.0,
         'prediction_distance': 2.0},
        {'source': 5, 'target': 3, 'score': 1.0,
         'prediction_distance': 0.0},
    ]
    db_name = 'linajea_test_solver'
    db_host = 'localhost'
    graph_provider = linajea.CandidateDatabase(db_name, db_host)
    # Drop the test database even when something below fails, so a failing
    # run cannot leak state into later runs that reuse this database name.
    try:
        roi = daisy.Roi((0, 0, 0, 0), (4, 5, 5, 5))
        graph = graph_provider[roi]
        ps1 = {
            "track_cost": 4.0,
            "weight_edge_score": 0.1,
            "weight_node_score": -0.1,
            "selection_constant": -1.0,
            "max_cell_move": 0.0,
            "block_size": [5, 100, 100, 100],
            "context": [2, 100, 100, 100],
        }
        ps2 = {
            # Making all the values smaller increases the
            # relative cost of division
            "track_cost": 1.0,
            "weight_edge_score": 0.01,
            "weight_node_score": -0.01,
            "selection_constant": -0.1,
            "max_cell_move": 0.0,
            "block_size": [5, 100, 100, 100],
            "context": [2, 100, 100, 100],
        }
        parameters = [
            linajea.tracking.TrackingParameters(**ps1),
            linajea.tracking.TrackingParameters(**ps2)
        ]
        keys = ['selected_1', 'selected_2']

        graph.add_nodes_from([(cell['id'], cell) for cell in cells])
        graph.add_edges_from(
            [(edge['source'], edge['target'], edge) for edge in edges])
        linajea.tracking.track(
            graph,
            parameters,
            frame_key='t',
            selected_key=keys)

        selected_edges_1 = []
        selected_edges_2 = []
        for u, v, data in graph.edges(data=True):
            if data['selected_1']:
                selected_edges_1.append((u, v))
            if data['selected_2']:
                selected_edges_2.append((u, v))

        expected_result_1 = [(1, 0), (2, 1), (3, 1), (5, 3)]
        expected_result_2 = [(1, 0), (3, 1), (5, 3)]
        self.assertCountEqual(selected_edges_1, expected_result_1)
        self.assertCountEqual(selected_edges_2, expected_result_2)
    finally:
        self.delete_db(db_name, db_host)
def extract_edges_in_block(
        db_name,
        db_host,
        edge_move_threshold,
        block,
        use_pv_distance=False):
    """Add candidate edges between consecutive frames of one block.

    Cells are read from ``block.read_roi``.  For each frame ``t`` in the
    time range of ``block.write_roi``, every cell in ``t`` is linked to all
    cells in ``t - 1`` that lie within ``edge_move_threshold`` of a query
    point: the cell's own centre, or the centre offset by its
    ``parent_vector`` when ``use_pv_distance`` is true.  Each edge stores
    ``distance`` (centre to centre) and ``prediction_distance`` (previous
    centre to the parent-vector-offset centre).  The edges are written
    for ``block.write_roi`` and the block is marked done.

    Returns:
        Always ``0``.
    """
    logger.info(
        "Finding edges in %s, reading from %s",
        block.write_roi, block.read_roi)

    start = time.time()
    provider = linajea.CandidateDatabase(db_name, db_host, mode='r+')
    graph = provider[block.read_roi]

    if graph.number_of_nodes() == 0:
        logger.info("No cells in roi %s. Skipping", block.read_roi)
        write_done(block, 'extract_edges', db_name, db_host)
        return 0

    logger.info(
        "Read %d cells in %.3fs",
        graph.number_of_nodes(), time.time() - start)

    start = time.time()
    t_begin = block.write_roi.get_begin()[0]
    t_end = block.write_roi.get_end()[0]

    # frame -> list of (cell id, zyx centre, parent vector); includes the
    # frame just before the write roi so its first frame can link back.
    cells_by_t = {}
    for frame in range(t_begin - 1, t_end):
        frame_cells = []
        for cell, data in graph.nodes(data=True):
            if 't' in data and data['t'] == frame:
                center = np.array([data[d] for d in ['z', 'y', 'x']])
                parent_vector = np.array(data['parent_vector'])
                frame_cells.append((cell, center, parent_vector))
        cells_by_t[frame] = frame_cells

    for nex in range(t_begin, t_end):
        pre = nex - 1
        all_pre_cells = cells_by_t[pre]
        all_nex_cells = cells_by_t[nex]

        logger.debug(
            "Finding edges between cells in frames %d and %d "
            "(%d and %d cells)",
            pre, nex, len(all_pre_cells), len(all_nex_cells))

        if len(all_pre_cells) == 0 or len(all_nex_cells) == 0:
            logger.debug("There are no edges between these frames, skipping")
            continue

        # prepare KD tree for fast lookup of 'pre' cells
        logger.debug("Preparing KD tree...")
        pre_kd_tree = cKDTree([center for _, center, _ in all_pre_cells])

        for nex_id, nex_center, nex_parent_vector in all_nex_cells:
            nex_parent_center = nex_center + nex_parent_vector
            if use_pv_distance:
                query_point = nex_parent_center
            else:
                query_point = nex_center
            neighbour_indices = pre_kd_tree.query_ball_point(
                query_point, edge_move_threshold)

            logger.debug(
                "Linking to %d cells in previous frame",
                len(neighbour_indices))

            for index in neighbour_indices:
                pre_id, pre_center, _ = all_pre_cells[index]
                graph.add_edge(
                    nex_id, pre_id,
                    distance=np.linalg.norm(pre_center - nex_center),
                    prediction_distance=np.linalg.norm(
                        pre_center - nex_parent_center))

    logger.info("Found %d edges", graph.number_of_edges())
    logger.info("Extracted edges in %.3fs", time.time() - start)

    start = time.time()
    graph.write_edges(block.write_roi)
    logger.info("Wrote edges in %.3fs", time.time() - start)

    write_done(block, 'extract_edges', db_name, db_host)
    return 0
def test_solver_basic(self):
    # Runs nm_track on a small graph and checks which edges are selected.
    #
    #   x
    #  3|         /-4
    #  2|        /--3---5
    #  1|   0---1
    #  0|        \--2
    #    ------------------------------------ t
    #       0   1   2   3
    cells = [
        {'id': 0, 't': 0, 'z': 1, 'y': 1, 'x': 1, 'score': 2.0},
        {'id': 1, 't': 1, 'z': 1, 'y': 1, 'x': 1, 'score': 2.0},
        {'id': 2, 't': 2, 'z': 1, 'y': 1, 'x': 0, 'score': 2.0},
        {'id': 3, 't': 2, 'z': 1, 'y': 1, 'x': 2, 'score': 2.0},
        {'id': 4, 't': 2, 'z': 1, 'y': 1, 'x': 3, 'score': 2.0},
        {'id': 5, 't': 3, 'z': 1, 'y': 1, 'x': 2, 'score': 2.0}
    ]
    edges = [
        {'source': 1, 'target': 0, 'score': 1.0,
         'prediction_distance': 0.0},
        {'source': 2, 'target': 1, 'score': 1.0,
         'prediction_distance': 1.0},
        {'source': 3, 'target': 1, 'score': 1.0,
         'prediction_distance': 1.0},
        {'source': 4, 'target': 1, 'score': 1.0,
         'prediction_distance': 2.0},
        {'source': 5, 'target': 3, 'score': 1.0,
         'prediction_distance': 0.0},
    ]
    db_name = 'linajea_test_solver'
    db_host = 'localhost'
    graph_provider = linajea.CandidateDatabase(
        db_name,
        db_host)
    # Drop the test database even when something below fails, so a failing
    # run cannot leak state into later runs that reuse this database name.
    try:
        roi = daisy.Roi((0, 0, 0, 0), (4, 5, 5, 5))
        graph = graph_provider[roi]
        ps = {
            "cost_appear": 2.0,
            "cost_disappear": 2.0,
            "cost_split": 0,
            "weight_prediction_distance_cost": 0.1,
            "weight_node_score": 1.0,
            "threshold_node_score": 0.0,
            "threshold_edge_score": 2.0,
            "max_cell_move": 0.0,
            "block_size": [5, 100, 100, 100],
            "context": [2, 100, 100, 100],
        }
        parameters = linajea.tracking.NMTrackingParameters(**ps)

        graph.add_nodes_from([(cell['id'], cell) for cell in cells])
        graph.add_edges_from(
            [(edge['source'], edge['target'], edge) for edge in edges])
        linajea.tracking.nm_track(
            graph,
            parameters,
            frame_key='t',
            selected_key='selected')

        selected_edges = [
            (u, v)
            for u, v, data in graph.edges(data=True)
            if data['selected']
        ]
        expected_result = [
            (1, 0),
            (2, 1),
            (3, 1),
            (5, 3)
        ]
        self.assertCountEqual(selected_edges, expected_result)
    finally:
        self.delete_db(db_name, db_host)
def read_data(self, data):
    """Load ground-truth tracks and matched reconstructed tracks.

    Args:
        data: Mapping with the keys ``'db_name'``, ``'frames'`` (a
            ``(start, end)`` pair with ``end > start``) and
            ``'gt_db_name'``; optional keys are ``'matching_threshold'``
            (default 20) and ``'parameters_id'``.  An integer-like
            ``parameters_id`` selects the key ``'selected_<id>'``; any
            other value is used as the selected key verbatim.

    Returns:
        ``(cells, tracks)`` accumulated from ``self.add_track``: ground
        truth tracks first (group 0), then reconstructed tracks that
        contain at least one matched edge (group 1).  ``([], [])`` when
        either graph has no nodes.
    """
    candidate_db_name = data['db_name']
    start_frame, end_frame = data['frames']
    matching_threshold = data.get('matching_threshold', 20)
    gt_db_name = data['gt_db_name']
    assert end_frame > start_frame
    roi = Roi((start_frame, 0, 0, 0),
              (end_frame - start_frame, 1e10, 1e10, 1e10))

    if 'parameters_id' in data:
        parameters_id = data['parameters_id']
        try:
            int(parameters_id)
        except (TypeError, ValueError, OverflowError):
            # Not integer-like: treat the value as the key itself.
            selected_key = parameters_id
        else:
            selected_key = 'selected_' + str(parameters_id)
    else:
        selected_key = None

    db = linajea.CandidateDatabase(
        candidate_db_name, self.db_host)
    db.selected_key = selected_key
    gt_db = linajea.CandidateDatabase(gt_db_name, self.db_host)

    print("Reading GT cells and edges in %s" % roi)
    gt_subgraph = gt_db[roi]
    gt_graph = linajea.tracking.TrackGraph(gt_subgraph, frame_key='t')
    gt_tracks = list(gt_graph.get_tracks())
    print("Found %d GT tracks" % len(gt_tracks))

    print("Reading cells and edges in %s" % roi)
    subgraph = db.get_selected_graph(roi)
    graph = linajea.tracking.TrackGraph(subgraph, frame_key='t')
    tracks = list(graph.get_tracks())
    print("Found %d tracks" % len(tracks))

    if len(graph.nodes) == 0 or len(gt_graph.nodes) == 0:
        logger.info("Didn't find gt or reconstruction - returning")
        return [], []

    m = linajea.evaluation.match_edges(
        gt_graph, graph,
        matching_threshold=matching_threshold)
    # Only the reconstruction edge list and the match pairs are needed.
    (_, edges_y, edge_matches, _) = m

    # Keep each reconstructed track that contains at least one matched edge.
    matched_rec_tracks = []
    for track in tracks:
        for _, edge_index in edge_matches:
            edge = edges_y[edge_index]
            if track.has_edge(edge[0], edge[1]):
                matched_rec_tracks.append(track)
                break
    logger.debug("found %d matched rec tracks" % len(matched_rec_tracks))

    logger.info("Adding %d gt tracks" % len(gt_tracks))
    track_id = 0
    cells = []
    tracks = []
    for track in gt_tracks:
        result = self.add_track(track, track_id, group=0)
        # Check for None before indexing; printing result[0] first would
        # raise TypeError when add_track returns None.
        if result is None:
            continue
        print(result[0])
        if len(result[0]) == 0:
            continue
        track_cells, track = result
        cells += track_cells
        tracks.append(track)
        track_id += 1

    logger.info("Adding %d matched rec tracks" % len(matched_rec_tracks))
    for track in matched_rec_tracks:
        result = self.add_track(track, track_id, group=1)
        if result is None:
            continue
        track_cells, track = result
        cells += track_cells
        tracks.append(track)
        track_id += 1

    return cells, tracks
def get_node_ids_in_frame(gt_db_name, mongo_url, frame):
    """Return the ``'id'`` of every node read from the single frame ``frame``.

    Nodes are read from database ``gt_db_name`` (reached via ``mongo_url``)
    using a roi that is one frame long and ``10e6`` wide on the other
    three axes.
    """
    provider = linajea.CandidateDatabase(gt_db_name, mongo_url)
    frame_roi = daisy.Roi((frame, 0, 0, 0), (1, 10e6, 10e6, 10e6))
    frame_nodes = provider.read_nodes(frame_roi)
    return [node['id'] for node in frame_nodes]