def get_result(
        setup,
        region,
        tracking_parameters,
        db_host,
        frames=None,
        sample=None,
        iteration='400000'):
    ''' Get the scores, statistics, and parameters for one set of tracking
    parameters in the given setup and region.

    Args:
        setup: setup name; part of the database name.
        region: region name; part of the database name.
        tracking_parameters: a dict of TrackingParameters keyword arguments
            or a TrackingParameters object.
        db_host: host of the candidate database.
        frames: optional frame restriction, forwarded to ``get_score``.
        sample: sample name; when not given (or empty) it is derived from
            ``setup`` with ``get_sample_from_setup``.
        iteration: checkpoint iteration, as a str or an int.

    Returns:
        The result of ``CandidateDatabase.get_score`` for the matching
        parameters id: a dictionary containing the keys and values of the
        score object.

    Raises:
        Presumably whatever ``get_parameters_id`` raises when called with
        ``fail_if_not_exists=True`` and the parameters are not stored
        (TODO confirm the exception type).
    '''
    if not sample:
        sample = get_sample_from_setup(setup)
    # str() so that callers may also pass the iteration as an int; joining
    # an int would raise TypeError
    db_name = '_'.join(['linajea', sample, setup, region, str(iteration)])
    candidate_db = CandidateDatabase(db_name, db_host, 'r')
    if isinstance(tracking_parameters, dict):
        tracking_parameters = TrackingParameters(**tracking_parameters)
    parameters_id = candidate_db.get_parameters_id(
        tracking_parameters,
        fail_if_not_exists=True)
    return candidate_db.get_score(parameters_id, frames=frames)
def test_unique_id_one_worker(self):
    ''' Check that get_parameters_id hands out the ids 1..10, in order, to
    ten parameter sets that differ only in cost_appear. '''
    db_name = 'test_linajea_db'
    db_host = 'localhost'
    db = CandidateDatabase(db_name, db_host, mode='w')
    try:
        for i in range(10):
            tp = linajea.tracking.TrackingParameters(
                **self.get_tracking_params())
            tp.cost_appear = i
            _id = db.get_parameters_id(tp)
            self.assertEqual(_id, i + 1)
    finally:
        # always drop the database, even when an assertion fails, so a
        # failed run does not leave ids behind for the next one
        self.delete_db(db_name, db_host)
def get_tgmm_results(
        region,
        db_host,
        sample,
        frames=None):
    ''' Load the stored TGMM scores for a sample as a pandas DataFrame.

    The database is named ``linajea_<sample>_tgmm``, with ``_<region>``
    appended when a region is given.

    Returns:
        A DataFrame built from ``get_scores(frames=frames)``, or None when
        that call returns None or an empty result.
    '''
    name_parts = ['linajea', sample, 'tgmm']
    if region is not None:
        name_parts.append(region)
    candidate_db = CandidateDatabase('_'.join(name_parts), db_host, 'r')

    scores = candidate_db.get_scores(frames=frames)
    if scores is None or len(scores) == 0:
        return None
    return pandas.DataFrame(scores)
def test_write_and_get_score(self):
    ''' Write a Report score for one parameter set and check that get_score
    returns the score fields merged with that set's parameters. '''
    db_name = 'test_linajea_database'
    db_host = 'localhost'
    ps = {
        "track_cost": 2.0,
        "weight_edge_score": 0.1,
        "weight_node_score": 1.0,
        "selection_constant": 0.0,
        "max_cell_move": 1.0,
        "block_size": [5, 100, 100, 100],
        "context": [2, 100, 100, 100],
    }
    parameters = linajea.tracking.TrackingParameters(**ps)
    db = CandidateDatabase(db_name, db_host)
    try:
        params_id = db.get_parameters_id(parameters)
        score = Report()
        score.gt_edges = 2
        score.matched_edges = 2
        score.fp_edges = 1
        score.fn_edges = 0
        db.write_score(params_id, score)
        score_dict = db.get_score(params_id)

        # copy: updating score.__dict__ directly would add the parameter
        # fields as attributes on the score object itself
        compare_dict = dict(vars(score))
        compare_dict.update(db.get_parameters(params_id))
        self.assertDictEqual(compare_dict, score_dict)
    finally:
        # other tests in this file use the same database name, so never
        # leave it behind
        self.delete_db(db_name, db_host)
def test_database_creation(self):
    ''' Write a 5-node chain through one CandidateDatabase and check that a
    second, read-mode instance sees exactly the same nodes and edges. '''
    db_name = 'test_linajea_database'
    db_host = 'localhost'
    total_roi = Roi((0, 0, 0, 0), (10, 100, 100, 100))
    candidate_db = CandidateDatabase(db_name, db_host, mode='w',
                                     total_roi=total_roi)
    try:
        sub_graph = candidate_db[total_roi]
        points = []
        for i in range(5):
            points.append((i, {
                'id': i,
                't': i,
                'z': i,
                'y': i,
                'x': i,
            }))
        # edges point from each node to its predecessor: (i + 1) -> i
        edges = []
        for i in range(4):
            edges.append((i + 1, i))
        sub_graph.add_nodes_from(points)
        sub_graph.add_edges_from(edges)
        sub_graph.write_nodes()
        sub_graph.write_edges()

        logger.debug("Creating new database to read data")
        compare_db = CandidateDatabase(db_name, db_host, mode='r',
                                       total_roi=total_roi)
        compare_sub_graph = compare_db[total_roi]

        point_ids = [p[0] for p in points]
        self.assertCountEqual(compare_sub_graph.nodes, point_ids)
        self.assertCountEqual(compare_sub_graph.edges, edges)
    finally:
        # clean up even when an assertion fails
        self.delete_db(db_name, db_host)
def get_results(
        setup,
        region,
        db_host,
        sample=None,
        iteration='400000',
        frames=None,
        filter_params=None):
    ''' Gets the scores, statistics, and parameters for all grid search
    configurations run for the given setup and region.

    Args:
        setup: setup name; part of the database name.
        region: region name; part of the database name.
        db_host: host of the candidate database.
        sample: sample name; when not given (or empty) it is derived from
            ``setup`` with ``get_sample_from_setup``.
        iteration: checkpoint iteration, as a str or an int.
        frames: optional frame restriction, forwarded to ``get_scores``.
        filter_params: forwarded to ``get_scores`` as ``filters``.

    Returns:
        A pandas DataFrame with one row per configuration. When the scores
        contain a ``param_id`` column, it becomes the index and is also
        kept as the ``_id`` column.
    '''
    if not sample:
        sample = get_sample_from_setup(setup)
    # str() so that callers may also pass the iteration as an int; joining
    # an int would raise TypeError
    db_name = '_'.join(['linajea', sample, setup, region, str(iteration)])
    candidate_db = CandidateDatabase(db_name, db_host, 'r')
    scores = candidate_db.get_scores(frames=frames, filters=filter_params)
    dataframe = pandas.DataFrame(scores)
    # lazy %-args: the dtypes are only formatted when debug logging is on
    logger.debug("data types of dataframe columns: %s", dataframe.dtypes)
    if 'param_id' in dataframe:
        # keep the id available as a regular column after it becomes the index
        dataframe['_id'] = dataframe['param_id']
        dataframe.set_index('param_id', inplace=True)
    return dataframe
def solve_in_block(
        db_host,
        db_name,
        parameters,
        block,
        parameters_id,
        solution_roi=None,
        cell_cycle_key=None):
    ''' Solve tracking for the given parameters inside one daisy block and
    write the resulting ``selected_<id>`` edge attributes to the database.

    Args:
        db_host, db_name: location of the candidate database (opened 'r+').
        parameters: list of tracking parameter objects. If the first one is
            an NMTrackingParameters, ``nm_track`` is used, otherwise
            ``track``.
        block: daisy block providing ``read_roi`` and ``write_roi``.
        parameters_id: database ids, one per entry of ``parameters``.
        solution_roi: the total roi you want a solution in. Limiting the
            block to it allows solving all the way to its edge, without
            reading data from outside it or paying the appear or disappear
            costs unnecessarily.
        cell_cycle_key: forwarded to ``track``.

    Returns:
        0. The block is marked done with ``write_done`` on both the
        "no edges" path and the normal path.
    '''
    # progress name: the single id, or a hash of the whole id set
    suffix = (parameters_id[0] if len(parameters_id) == 1
              else hash(frozenset(parameters_id)))
    step_name = 'solve_' + str(suffix)

    logger.debug("Solving in block %s", block)
    read_roi, write_roi = block.read_roi, block.write_roi
    if solution_roi:
        # clip the block to the solution roi
        logger.debug("Block write roi: %s", block.write_roi)
        logger.debug("Solution roi: %s", solution_roi)
        read_roi = read_roi.intersect(solution_roi)
        write_roi = write_roi.intersect(solution_roi)
    logger.debug("Write roi: %s", str(write_roi))

    graph_provider = CandidateDatabase(db_name, db_host, mode='r+')
    read_start = time.time()
    selected_keys = ['selected_' + str(pid) for pid in parameters_id]
    graph = graph_provider.get_graph(
        read_roi,
        edge_attrs=selected_keys + ["prediction_distance", "distance"])

    # remove dangling nodes (those without a 't' attribute) and their edges
    graph.remove_nodes_from([
        node for node, attrs in graph.nodes(data=True) if 't' not in attrs
    ])
    num_nodes = graph.number_of_nodes()
    num_edges = graph.number_of_edges()
    logger.info(
        "Reading graph with %d nodes and %d edges took %s seconds",
        num_nodes, num_edges, time.time() - read_start)

    if num_edges == 0:
        logger.info("No edges in roi %s. Skipping", read_roi)
        write_done(block, step_name, db_name, db_host)
        return 0

    first_frame = read_roi.get_offset()[0]
    frames = [first_frame, first_frame + read_roi.get_shape()[0]]
    if isinstance(parameters[0], NMTrackingParameters):
        nm_track(graph, parameters, selected_keys, frames=frames)
    else:
        track(graph, parameters, selected_keys, frames=frames,
              cell_cycle_key=cell_cycle_key)

    write_start = time.time()
    graph.update_edge_attrs(write_roi, attributes=selected_keys)
    logger.info(
        "Updating %d keys for %d edges took %s seconds",
        len(selected_keys), num_edges, time.time() - write_start)

    write_done(block, step_name, db_name, db_host)
    return 0
def solve_blockwise(
        db_host,
        db_name,
        sample,
        parameters,  # list of TrackingParameters
        num_workers=8,
        frames=None,
        limit_to_roi=None,
        from_scratch=False,
        data_dir='../01_data',
        cell_cycle_key=None,
        **kwargs):
    ''' Solve tracking blockwise with daisy for a list of parameter sets.

    Parameters whose ``solve_<id>`` step is already marked done for all
    blocks are skipped. The remaining ones are solved together in a single
    ``daisy.run_blockwise`` call that runs ``solve_in_block`` on each block.

    Args:
        db_host, db_name: location of the candidate database.
        sample: sample name, passed to ``get_source_roi``.
        parameters: list of tracking parameter objects. All entries must
            share the same ``block_size`` and ``context``. The caller's
            list is not modified.
        num_workers: number of daisy workers.
        frames: optional ``(begin, end)`` pair; solving is restricted to
            frames ``[begin, end)``.
        limit_to_roi: optional roi to further restrict solving to.
        from_scratch: if True, reset the selection for every parameter id
            first.
        data_dir: data directory passed to ``get_source_roi``.
        cell_cycle_key: forwarded to ``solve_in_block``.
        **kwargs: accepted and ignored.

    Returns:
        True when everything was already solved, otherwise the result of
        ``daisy.run_blockwise``.
    '''
    # Work on a copy: already-solved parameters are pruned below, and that
    # must not modify the caller's list as a side effect.
    parameters = list(parameters)

    block_size = daisy.Coordinate(parameters[0].block_size)
    context = daisy.Coordinate(parameters[0].context)
    # block size and context must be the same for all parameters!
    # (compared as lists so tuple- and list-valued settings both match)
    for p in parameters:
        assert list(block_size) == list(p.block_size),\
            "%s not equal to %s" % (block_size, p.block_size)
        assert list(context) == list(p.context),\
            "%s not equal to %s" % (context, p.context)

    voxel_size, source_roi = get_source_roi(data_dir, sample)

    # determine parameters id from database
    graph_provider = CandidateDatabase(db_name, db_host)
    parameters_id = [graph_provider.get_parameters_id(p) for p in parameters]

    if from_scratch:
        # NOTE(review): nothing here clears the "done" markers checked
        # below. Unless reset_selection does, already completed parameters
        # are still skipped. Confirm this is intended.
        for pid in parameters_id:
            graph_provider.set_parameters_id(pid)
            graph_provider.reset_selection()

    # limit to specific frames, if given
    if frames:
        # lazy %-args: "%s" % frames would raise TypeError for a 2-tuple
        logger.info("Solving in frames %s", frames)
        begin, end = frames
        crop_roi = daisy.Roi(
            (begin, None, None, None),
            (end - begin, None, None, None))
        source_roi = source_roi.intersect(crop_roi)

    # limit to roi, if given
    if limit_to_roi:
        logger.info("limiting to roi %s", str(limit_to_roi))
        source_roi = source_roi.intersect(limit_to_roi)

    block_write_roi = daisy.Roi((0, 0, 0, 0), block_size)
    block_read_roi = block_write_roi.grow(context, context)
    total_roi = source_roi.grow(context, context)
    logger.info("Solving in %s", total_roi)

    def _step_name(ids):
        # Same naming rule solve_in_block uses for its per-block progress:
        # the single id, or a hash of the id set.
        if len(ids) == 1:
            return 'solve_' + str(ids[0])
        return 'solve_' + str(hash(frozenset(ids)))

    param_names = ['solve_' + str(_id) for _id in parameters_id]
    set_step_name = _step_name(parameters_id)
    if len(parameters_id) > 1:
        # check if set of parameters is already done
        if check_function_all_blocks(set_step_name, db_name, db_host):
            logger.info("Param set with name %s already completed. Exiting",
                        set_step_name)
            return True

    # Check each individual parameter to see if it is done, and keep only
    # the ones still to solve. zip keeps ids, params and names aligned,
    # including when an id appears more than once.
    todo_ids = []
    todo_params = []
    todo_names = []
    for _id, params, name in zip(parameters_id, parameters, param_names):
        if check_function_all_blocks(name, db_name, db_host):
            logger.info("Params with id %d already completed. Removing", _id)
            continue
        todo_ids.append(_id)
        todo_params.append(params)
        todo_names.append(name)
    logger.debug(todo_ids)
    if not todo_ids:
        logger.info("All parameters in set already completed. Exiting")
        return True

    # solve_in_block derives its per-block step name from the ids it is
    # given (the remaining ones). The block check must use that same name,
    # otherwise blocks finished in an interrupted run are never recognized.
    step_name = _step_name(todo_ids)

    success = daisy.run_blockwise(
        total_roi,
        block_read_roi,
        block_write_roi,
        process_function=lambda b: solve_in_block(
            db_host,
            db_name,
            todo_params,
            b,
            todo_ids,
            solution_roi=source_roi,
            cell_cycle_key=cell_cycle_key),
        # Note: in the case of a set of parameters, we assume that none of
        # the individual parameters are half done and only check the
        # set's step name for each block
        check_function=lambda b: check_function(
            b,
            step_name,
            db_name,
            db_host),
        num_workers=num_workers,
        fit='overhang')

    if success:
        # Every parameter in the original set is now done (the pruned ones
        # were done already), so mark the set and each solved parameter.
        if len(parameters_id) > 1:
            write_done_all_blocks(set_step_name, db_name, db_host)
        for name in todo_names:
            write_done_all_blocks(name, db_name, db_host)
    logger.info("Finished solving")
    return success
def test_get_selected_graph_and_reset_selection(self):
    ''' Check two things for parameters id 1. First, get_selected_graph
    returns only the three edges marked ``selected_1`` (spanning four
    nodes). Second, it returns an empty graph after reset_selection. '''
    db_name = 'test_linajea_database'
    db_host = 'localhost'
    total_roi = Roi((0, 0, 0, 0), (5, 10, 10, 10))
    write_db = CandidateDatabase(db_name, db_host, mode='w',
                                 total_roi=total_roi)
    try:
        sub_graph = write_db[total_roi]
        points = [
            (1, {'t': 0, 'z': 1, 'y': 2, 'x': 3}),
            (2, {'t': 1, 'z': 1, 'y': 2, 'x': 3}),
            (3, {'t': 2, 'z': 1, 'y': 2, 'x': 3}),
            (4, {'t': 3, 'z': 1, 'y': 2, 'x': 3}),
            (5, {'t': 2, 'z': 5, 'y': 2, 'x': 3}),
            (6, {'t': 3, 'z': 5, 'y': 2, 'x': 3}),
        ]
        # the chain 1 <- 2 <- 3 <- 4 is selected; the branch via 5 and 6 is not
        edges = [
            (2, 1, {'selected_1': True}),
            (3, 2, {'selected_1': True}),
            (4, 3, {'selected_1': True}),
            (5, 2, {'selected_1': False}),
            (6, 5, {'selected_1': False}),
        ]
        sub_graph.add_nodes_from(points)
        sub_graph.add_edges_from(edges)
        sub_graph.write_nodes()
        sub_graph.write_edges()

        logger.debug("Creating new database to read data")
        read_db = CandidateDatabase(db_name, db_host, mode='r',
                                    parameters_id=1)
        selected_graph = read_db.get_selected_graph(total_roi)
        self.assertEqual(selected_graph.number_of_nodes(), 4)
        self.assertEqual(selected_graph.number_of_edges(), 3)

        read_db.reset_selection()
        unselected_graph = read_db.get_selected_graph(total_roi)
        self.assertEqual(unselected_graph.number_of_nodes(), 0)
        self.assertEqual(unselected_graph.number_of_edges(), 0)
    finally:
        # other tests in this file use the same database name, so never
        # leave this test's nodes and edges behind
        self.delete_db(db_name, db_host)