Пример #1
0
def get_result(
        setup,
        region,
        tracking_parameters,
        db_host,
        frames=None,
        sample=None,
        iteration='400000'):
    '''Fetch the stored score for one tracking-parameter configuration.

    The candidate database read is named
    ``linajea_<sample>_<setup>_<region>_<iteration>`` and opened in mode
    'r'. When ``sample`` is falsy it is derived from ``setup`` with
    ``get_sample_from_setup``.

    Args:
        setup, region, iteration: parts of the database name.
        tracking_parameters: either a dict of keyword arguments for
            ``TrackingParameters`` or a ``TrackingParameters`` object.
        db_host: host of the candidate database.
        frames: forwarded to ``get_score``.
        sample: optional sample name (see above).

    Returns:
        The result of ``CandidateDatabase.get_score`` for the parameters
        id of ``tracking_parameters`` (described by the original author as
        a dictionary of the score object's keys and values).

    The parameters id is looked up with ``fail_if_not_exists=True``.
    '''
    sample = sample or get_sample_from_setup(setup)
    candidate_db = CandidateDatabase(
        '_'.join(('linajea', sample, setup, region, iteration)),
        db_host,
        'r')
    params = (TrackingParameters(**tracking_parameters)
              if isinstance(tracking_parameters, dict)
              else tracking_parameters)
    parameters_id = candidate_db.get_parameters_id(
        params, fail_if_not_exists=True)
    return candidate_db.get_score(parameters_id, frames=frames)
Пример #2
0
 def test_unique_id_one_worker(self):
     '''Parameter ids are handed out sequentially, starting at 1.

     Ten TrackingParameters that differ only in ``cost_appear`` are
     registered in a database opened with mode 'w'. The i-th one must
     receive id i + 1.
     '''
     db_name = 'test_linajea_db'
     db_host = 'localhost'
     db = CandidateDatabase(db_name, db_host, mode='w')
     try:
         for i in range(10):
             tp = linajea.tracking.TrackingParameters(
                 **self.get_tracking_params())
             tp.cost_appear = i
             _id = db.get_parameters_id(tp)
             self.assertEqual(_id, i + 1)
     finally:
         # delete the test database even when an assertion fails
         self.delete_db(db_name, db_host)
Пример #3
0
def get_tgmm_results(
        region,
        db_host,
        sample,
        frames=None):
    '''Load every score stored in the TGMM candidate database of a sample.

    The database read is ``linajea_<sample>_tgmm`` or, when ``region`` is
    not None, ``linajea_<sample>_tgmm_<region>``; it is opened in mode 'r'.

    Returns:
        A pandas DataFrame built from ``get_scores(frames=frames)``, or
        None when ``get_scores`` returns None or an empty collection.
    '''
    name_parts = ['linajea', sample, 'tgmm']
    if region is not None:
        name_parts.append(region)
    candidate_db = CandidateDatabase('_'.join(name_parts), db_host, 'r')
    results = candidate_db.get_scores(frames=frames)
    has_results = results is not None and len(results) > 0
    return pandas.DataFrame(results) if has_results else None
Пример #4
0
    def test_write_and_get_score(self):
        '''A score written with ``write_score`` is read back by
        ``get_score``, merged with the parameters it was written for.'''
        db_name = 'test_linajea_database'
        db_host = 'localhost'
        ps = {
            "track_cost": 2.0,
            "weight_edge_score": 0.1,
            "weight_node_score": 1.0,
            "selection_constant": 0.0,
            "max_cell_move": 1.0,
            "block_size": [5, 100, 100, 100],
            "context": [2, 100, 100, 100],
        }
        parameters = linajea.tracking.TrackingParameters(**ps)

        db = CandidateDatabase(db_name, db_host)
        params_id = db.get_parameters_id(parameters)

        score = Report()
        score.gt_edges = 2
        score.matched_edges = 2
        score.fp_edges = 1
        score.fn_edges = 0
        db.write_score(params_id, score)
        score_dict = db.get_score(params_id)

        # Build a new dict instead of updating score.__dict__ in place,
        # which would inject the parameters as attributes of ``score``.
        compare_dict = {**vars(score), **db.get_parameters(params_id)}
        self.assertDictEqual(compare_dict, score_dict)
Пример #5
0
    def test_database_creation(self):
        '''Nodes and edges written through one CandidateDatabase (mode 'w')
        are read back unchanged through a second one (mode 'r').'''
        db_name = 'test_linajea_database'
        db_host = 'localhost'
        total_roi = Roi((0, 0, 0, 0), (10, 100, 100, 100))

        candidate_db = CandidateDatabase(db_name,
                                         db_host,
                                         mode='w',
                                         total_roi=total_roi)
        try:
            sub_graph = candidate_db[total_roi]
            # five nodes; node i sits at t = z = y = x = i
            points = [
                (i, {'id': i, 't': i, 'z': i, 'y': i, 'x': i})
                for i in range(5)]
            # chain linking each node to its predecessor
            edges = [(i + 1, i) for i in range(4)]
            sub_graph.add_nodes_from(points)
            sub_graph.add_edges_from(edges)
            sub_graph.write_nodes()
            sub_graph.write_edges()

            logger.debug("Creating new database to read data")
            compare_db = CandidateDatabase(db_name,
                                           db_host,
                                           mode='r',
                                           total_roi=total_roi)

            compare_sub_graph = compare_db[total_roi]

            point_ids = [node_id for node_id, _ in points]
            self.assertCountEqual(compare_sub_graph.nodes, point_ids)
            self.assertCountEqual(compare_sub_graph.edges, edges)
        finally:
            # delete the test database even when an assertion fails
            self.delete_db(db_name, db_host)
Пример #6
0
def get_results(
        setup,
        region,
        db_host,
        sample=None,
        iteration='400000',
        frames=None,
        filter_params=None):
    ''' Gets the scores, statistics, and parameters for all
    grid search configurations run for the given setup and region.

    The database read is ``linajea_<sample>_<setup>_<region>_<iteration>``
    (opened in mode 'r'). When ``sample`` is falsy it is derived from
    ``setup`` with ``get_sample_from_setup``. ``frames`` and
    ``filter_params`` are forwarded to ``get_scores`` (the latter as
    ``filters``).

    Returns a pandas dataframe with one row per configuration. If the
    scores contain a ``param_id`` column, it becomes the index and its
    values are also kept in an ``_id`` column.'''
    if not sample:
        sample = get_sample_from_setup(setup)
    db_name = '_'.join(['linajea', sample, setup, region, iteration])
    candidate_db = CandidateDatabase(db_name, db_host, 'r')
    scores = candidate_db.get_scores(frames=frames, filters=filter_params)
    dataframe = pandas.DataFrame(scores)
    # lazy logging args: the dtypes string is only built if debug is on
    logger.debug("data types of dataframe columns: %s", dataframe.dtypes)
    if 'param_id' in dataframe:
        dataframe['_id'] = dataframe['param_id']
        dataframe.set_index('param_id', inplace=True)
    return dataframe
Пример #7
0
def solve_in_block(
        db_host,
        db_name,
        parameters,
        block,
        parameters_id,
        solution_roi=None,
        cell_cycle_key=None):
    '''Solve the tracking problem for all given parameter sets in one block.

    The function proceeds in these steps:

    1. Read the candidate graph (edge attributes ``selected_<id>``,
       ``prediction_distance`` and ``distance``) in the block's read roi.
    2. Drop nodes without a ``t`` attribute.
    3. Run ``nm_track`` or ``track`` to fill the ``selected_<id>`` keys.
    4. Write those keys back for the block's write roi.
    5. Mark the block done with ``write_done``.

    Args:
        db_host, db_name: location of the candidate database (opened 'r+').
        parameters: list of parameter objects, one per entry of
            ``parameters_id``. If the first one is an
            ``NMTrackingParameters``, ``nm_track`` is used, else ``track``.
        block: block providing ``read_roi`` and ``write_roi``.
        parameters_id: list of parameter ids. The done-marker step name is
            ``solve_<id>`` for a single id and ``solve_<hash of id set>``
            otherwise.
        solution_roi: total roi a solution is wanted in. Clipping the
            block to it allows solving all the way to its edge, without
            reading data from outside it or paying the appear or disappear
            costs unnecessarily.
        cell_cycle_key: forwarded to ``track``.

    Returns:
        0, also when the block contains no edges and is skipped.
    '''
    # This step name must match the one solve_blockwise checks per block.
    if len(parameters_id) == 1:
        step_name = 'solve_' + str(parameters_id[0])
    else:
        step_name = 'solve_' + str(hash(frozenset(parameters_id)))
    logger.debug("Solving in block %s", block)

    if solution_roi:
        # Limit block to solution_roi
        logger.debug("Block write roi: %s", block.write_roi)
        logger.debug("Solution roi: %s", solution_roi)
        read_roi = block.read_roi.intersect(solution_roi)
        write_roi = block.write_roi.intersect(solution_roi)
    else:
        read_roi = block.read_roi
        write_roi = block.write_roi

    logger.debug("Write roi: %s", write_roi)

    graph_provider = CandidateDatabase(
        db_name,
        db_host,
        mode='r+')
    start_time = time.time()
    selected_keys = ['selected_' + str(pid) for pid in parameters_id]
    edge_attrs = selected_keys + ["prediction_distance", "distance"]
    graph = graph_provider.get_graph(read_roi, edge_attrs=edge_attrs)

    # remove dangling nodes and edges (nodes lacking a 't' attribute)
    graph.remove_nodes_from([
        n for n, data in graph.nodes(data=True) if 't' not in data])

    num_nodes = graph.number_of_nodes()
    num_edges = graph.number_of_edges()
    logger.info("Reading graph with %d nodes and %d edges took %s seconds",
                num_nodes, num_edges, time.time() - start_time)

    if num_edges == 0:
        logger.info("No edges in roi %s. Skipping", read_roi)
        write_done(block, step_name, db_name, db_host)
        return 0

    # time range of the read roi: [first frame, first frame + #frames]
    first_frame = read_roi.get_offset()[0]
    frames = [first_frame, first_frame + read_roi.get_shape()[0]]
    if isinstance(parameters[0], NMTrackingParameters):
        nm_track(graph, parameters, selected_keys, frames=frames)
    else:
        track(graph, parameters, selected_keys, frames=frames,
              cell_cycle_key=cell_cycle_key)

    start_time = time.time()
    # only the write roi is updated; the context area is read-only
    graph.update_edge_attrs(write_roi, attributes=selected_keys)
    logger.info("Updating %d keys for %d edges took %s seconds",
                len(selected_keys), num_edges, time.time() - start_time)
    write_done(block, step_name, db_name, db_host)
    return 0
Пример #8
0
def _solve_step_name(parameters_id):
    '''Return the done-marker step name for a list of parameter ids.

    A single id maps to ``solve_<id>``; several ids map to
    ``solve_<hash of the id set>``. This mirrors the naming used by
    ``solve_in_block`` when it writes per-block done markers, so the two
    must stay in sync.
    '''
    if len(parameters_id) == 1:
        return 'solve_' + str(parameters_id[0])
    return 'solve_' + str(hash(frozenset(parameters_id)))


def solve_blockwise(
        db_host,
        db_name,
        sample,
        parameters,  # list of TrackingParameters
        num_workers=8,
        frames=None,
        limit_to_roi=None,
        from_scratch=False,
        data_dir='../01_data',
        cell_cycle_key=None,
        **kwargs):
    '''Solve the tracking problem blockwise for a list of parameter sets.

    All parameter sets must share ``block_size`` and ``context``
    (checked with asserts). The solve region is the sample's source roi
    (from ``get_source_roi(data_dir, sample)``), optionally restricted to
    ``frames`` = (begin, end) and to ``limit_to_roi``, then grown by
    ``context``.

    Parameter sets whose ``solve_<id>`` marker is already set for all
    blocks are skipped. If the whole set is marked done, nothing is
    solved. The remaining sets are solved together with
    ``daisy.run_blockwise``. On success, all-blocks-done markers are
    written for the set(s) and for every solved parameter id.

    The caller's ``parameters`` list is not modified.

    Args:
        db_host, db_name: location of the candidate database.
        sample: sample name used to look up the source roi.
        parameters: list of parameter objects.
        num_workers: number of daisy workers.
        frames: optional (begin, end) frame range.
        limit_to_roi: optional roi to intersect the source roi with.
        from_scratch: reset the stored selection of every parameter id
            before solving.
        data_dir: data directory passed to ``get_source_roi``.
        cell_cycle_key: forwarded to ``solve_in_block``.
        kwargs: accepted and ignored.

    Returns:
        True if everything was already done, otherwise the return value of
        ``daisy.run_blockwise``.
    '''
    block_size = daisy.Coordinate(parameters[0].block_size)
    context = daisy.Coordinate(parameters[0].context)
    # block size and context must be the same for all parameters!
    for p in parameters:
        assert list(block_size) == p.block_size,\
            "%s not equal to %s" % (block_size, p.block_size)
        assert list(context) == p.context

    _, source_roi = get_source_roi(data_dir, sample)

    # determine parameters id from database
    graph_provider = CandidateDatabase(
        db_name,
        db_host)
    parameters_id = [graph_provider.get_parameters_id(p) for p in parameters]

    if from_scratch:
        # NOTE(review): nothing here clears the done markers checked below,
        # so parameters already marked done are still skipped after their
        # selection is reset. Confirm this is intended.
        for pid in parameters_id:
            graph_provider.set_parameters_id(pid)
            graph_provider.reset_selection()

    # limit to specific frames, if given
    if frames:
        # frames is passed as a logging argument: "%s" % frames would
        # raise TypeError when frames is a 2-tuple
        logger.info("Solving in frames %s", frames)
        begin, end = frames
        crop_roi = daisy.Roi(
            (begin, None, None, None),
            (end - begin, None, None, None))
        source_roi = source_roi.intersect(crop_roi)
    # limit to roi, if given
    if limit_to_roi:
        logger.info("limiting to roi %s", limit_to_roi)
        source_roi = source_roi.intersect(limit_to_roi)

    block_write_roi = daisy.Roi((0, 0, 0, 0), block_size)
    block_read_roi = block_write_roi.grow(context, context)
    total_roi = source_roi.grow(context, context)

    logger.info("Solving in %s", total_roi)

    param_names = ['solve_' + str(_id) for _id in parameters_id]
    set_step_name = None
    if len(parameters_id) > 1:
        # check if set of parameters is already done
        set_step_name = _solve_step_name(parameters_id)
        if check_function_all_blocks(set_step_name, db_name, db_host):
            logger.info("Param set with name %s already completed. Exiting",
                        set_step_name)
            return True

    # Check each individual parameter to see if it is done; if it is, drop
    # it. Filtering by position (instead of list.index) stays correct when
    # ids repeat, and building new lists leaves the caller's list intact.
    keep = []
    for index, (_id, name) in enumerate(zip(parameters_id, param_names)):
        if check_function_all_blocks(name, db_name, db_host):
            logger.info("Params with id %d already completed. Removing", _id)
        else:
            keep.append(index)
    parameters_id = [parameters_id[i] for i in keep]
    parameters = [parameters[i] for i in keep]
    param_names = [param_names[i] for i in keep]
    logger.debug(parameters_id)
    if not parameters_id:
        logger.info("All parameters in set already completed. Exiting")
        return True

    # Derive the per-block step name from the *remaining* ids so that it
    # matches the name solve_in_block writes for every finished block.
    step_name = _solve_step_name(parameters_id)

    success = daisy.run_blockwise(
        total_roi,
        block_read_roi,
        block_write_roi,
        process_function=lambda b: solve_in_block(
            db_host,
            db_name,
            parameters,
            b,
            parameters_id,
            solution_roi=source_roi,
            cell_cycle_key=cell_cycle_key),
        # Note: in the case of a set of parameters,
        # we are assuming that none of the individual parameters are
        # half done and only checking the hash for each block
        check_function=lambda b: check_function(
            b,
            step_name,
            db_name,
            db_host),
        num_workers=num_workers,
        fit='overhang')
    if success:
        # write all done to the set(s) and to the individual parameters
        if len(parameters_id) > 1:
            write_done_all_blocks(step_name, db_name, db_host)
        if set_step_name is not None and set_step_name != step_name:
            # the originally requested set is complete as well: its other
            # members were already done before this run
            write_done_all_blocks(set_step_name, db_name, db_host)
        for name in param_names:
            write_done_all_blocks(name, db_name, db_host)
    logger.info("Finished solving")
    return success
Пример #9
0
    def test_get_selected_graph_and_reset_selection(self):
        '''``get_selected_graph`` returns only the edges marked
        ``selected_1`` (3 edges, 4 nodes here) for a database opened with
        ``parameters_id=1``. After ``reset_selection`` it returns an empty
        graph.'''
        db_name = 'test_linajea_database'
        db_host = 'localhost'
        total_roi = Roi((0, 0, 0, 0), (5, 10, 10, 10))

        write_db = CandidateDatabase(db_name,
                                     db_host,
                                     mode='w',
                                     total_roi=total_roi)
        sub_graph = write_db[total_roi]

        # node id -> (t, z, y, x). Nodes 1-4 form a selected track; nodes
        # 5 and 6 hang off node 2 through unselected edges.
        coordinates = {
            1: (0, 1, 2, 3),
            2: (1, 1, 2, 3),
            3: (2, 1, 2, 3),
            4: (3, 1, 2, 3),
            5: (2, 5, 2, 3),
            6: (3, 5, 2, 3),
        }
        points = [
            (node_id, dict(zip(('t', 'z', 'y', 'x'), coords)))
            for node_id, coords in coordinates.items()
        ]
        # (source, target) -> value of the 'selected_1' edge attribute
        selection = {
            (2, 1): True,
            (3, 2): True,
            (4, 3): True,
            (5, 2): False,
            (6, 5): False,
        }
        edges = [
            (source, target, {'selected_1': is_selected})
            for (source, target), is_selected in selection.items()
        ]
        sub_graph.add_nodes_from(points)
        sub_graph.add_edges_from(edges)
        sub_graph.write_nodes()
        sub_graph.write_edges()

        logger.debug("Creating new database to read data")
        read_db = CandidateDatabase(db_name,
                                    db_host,
                                    mode='r',
                                    parameters_id=1)
        selected = read_db.get_selected_graph(total_roi)
        self.assertEqual(selected.number_of_nodes(), 4)
        self.assertEqual(selected.number_of_edges(), 3)

        read_db.reset_selection()
        unselected = read_db.get_selected_graph(total_roi)
        self.assertEqual(unselected.number_of_nodes(), 0)
        self.assertEqual(unselected.number_of_edges(), 0)