Example 1

def get_plot_data(self, path, sort_by='eta'):
        """
        :param sort_by: 'noise_factor' or 'long_range_prob'
        """
        # TODO: use experiment path

        # Create dictionary:
        results_collected = {}

        for item in os.listdir(path):
            if os.path.isfile(os.path.join(path, item)):
                filename = item
                if not filename.endswith(".json") or filename.startswith("."):
                    continue
                # ID, sample, agglo_type, _ = filename.split("_")
                result_file = os.path.join(path, filename)
                with open(result_file, 'rb') as f:
                    file_dict = json.load(f)

                if sort_by == 'eta':
                    sort_key = file_dict["eta"]
                else:
                    raise ValueError("Unsupported sort_by: {}".format(sort_by))
                method = file_dict["method_type"]
                if method == "GASP":
                    method_descriptor = file_dict["linkage_criteria"] + str(
                        file_dict["add_cannot_link_constraints"])
                elif method == "spectral":
                    method_descriptor = file_dict["spectral_method_name"]
                elif method == "multicut":
                    method_descriptor = file_dict["multicut_solver_type"]
                else:
                    raise ValueError("Unknown method_type: {}".format(method))
                new_results = {
                    method: {
                        method_descriptor: {
                            sort_key: {file_dict["ID"]: file_dict}
                        }
                    }
                }

                try:
                    results_collected = recursive_dict_update(
                        new_results, results_collected)
                except KeyError:
                    continue
        return results_collected
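Every example on this page revolves around recursive_dict_update, whose implementation is not shown here. The sketch below captures the behaviour implied by the call sites (entries of the first dict are merged into the second, recursing into nested dicts, and the updated second dict is returned); it is an assumption inferred from usage, not the library's actual code.

def recursive_dict_update(source, target):
    # Hypothetical sketch inferred from the call sites on this page;
    # the real implementation may differ in details (e.g. copying).
    for key, value in source.items():
        if isinstance(value, dict) and isinstance(target.get(key), dict):
            recursive_dict_update(value, target[key])
        else:
            target[key] = value
    return target

Under this reading, results_collected = recursive_dict_update(new_results, results_collected) merges the single new entry into the accumulated results without overwriting sibling branches.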
Example 2
def build_val_loader(self):
    scaling_factors = self.get("stack_scaling_factors_2")
    kwargs = recursive_dict_update(self.get('loaders/val'),
                                   deepcopy(self.get('loaders/general')))
    kwargs["volume_config"]["scaling_factors"] = scaling_factors
    return get_cremi_loader(kwargs)
Example 3

                result_dict = json.load(f)
            edge_prob = result_dict["edge_prob"]
            non_link = result_dict["non_link"]

            new_results = {}

            new_results[agglo_type] = {}
            new_results[agglo_type][str(non_link)] = {}
            new_results[agglo_type][str(non_link)][edge_prob] = {}
            new_results[agglo_type][str(non_link)][edge_prob][ID] = {
                'energy': result_dict["energy"],
                'score': result_dict["score"],
                'score_WS': result_dict["score_WS"],
                'runtime': result_dict["runtime"]
            }
            results_collected = recursive_dict_update(new_results,
                                                      results_collected)

    ncols, nrows = 1, 1
    f, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(7, 7))

    for agglo_type in results_collected:
        if agglo_type != "max":
            continue
        for non_link in results_collected[agglo_type]:
            sub_dict = results_collected[agglo_type][non_link]
            probs = []
            values = []
            error_bars = []
            for edge_prob in sub_dict:
                multiple_values = []
                for ID in sub_dict[edge_prob]:
Example 4
def build_val_loader(self):
    return get_cremi_loader(recursive_dict_update(self.get('loaders/val'),
                                                  self.get('loaders/general')))
Example 5
def build_val_loader(self):
    kwargs = recursive_dict_update(self.get('loaders/val'),
                                   deepcopy(self.get('loaders/general')))
    return get_cremi_loader(kwargs)
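Examples 4 and 5 differ only in the deepcopy around self.get('loaders/general'). If recursive_dict_update mutates its second argument in place (as in the sketch after Example 1), omitting the deepcopy lets validation-specific overrides leak into the shared general config. A minimal self-contained demonstration under that assumption:

from copy import deepcopy

general = {"batch_size": 1, "volume_config": {"stride": 8}}
val = {"volume_config": {"stride": 4}}

# Without deepcopy, merging mutates the shared defaults in place:
recursive_dict_update(val, general)
print(general["volume_config"]["stride"])  # 4 -- the defaults changed

# With deepcopy, the shared defaults stay intact:
general = {"batch_size": 1, "volume_config": {"stride": 8}}
kwargs = recursive_dict_update(val, deepcopy(general))
print(general["volume_config"]["stride"])  # still 8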
Example 6
def build_train_loader(self):
    return get_cremi_loader(
        recursive_dict_update(self.get('loaders/train'),
                              deepcopy(self.get('loaders/general'))))
Example 7

def get_segmentation(inverted_affinities, offsets, post_proc_config):
    n_threads = post_proc_config.pop('nb_threads')
    invert_affinities = post_proc_config.pop('invert_affinities', False)
    segm_pipeline_type = post_proc_config.pop('segm_pipeline_type', 'gen_HC')

    segmentation_pipeline = get_segmentation_pipeline(
        segm_pipeline_type,
        offsets,
        nb_threads=n_threads,
        invert_affinities=invert_affinities,
        return_fragments=False,
        **post_proc_config)

    if post_proc_config.get('use_final_agglomerater', False):
        final_agglomerater = GreedyEdgeContractionAgglomeraterFromSuperpixels(
            offsets,
            n_threads=n_threads,
            invert_affinities=invert_affinities,
            **post_proc_config['generalized_HC_kwargs']
            ['final_agglomeration_kwargs'])
    else:
        final_agglomerater = None

    post_proc_solver = BlockWise(
        segmentation_pipeline=segmentation_pipeline,
        offsets=offsets,
        final_agglomerater=final_agglomerater,
        blockwise=post_proc_config.get('blockwise', False),
        invert_affinities=invert_affinities,
        nb_threads=n_threads,
        return_fragments=False,
        blockwise_config=post_proc_config.get('blockwise_kwargs', {}))

    print("Starting prediction...")
    tick = time.time()
    outputs = post_proc_solver(affinities)
    comp_time = time.time() - tick
    print("Post-processing took {} s".format(comp_time))

    return outputs

    # print("Pred. sahpe: ", pred_segm.shape)
    # if not use_test_datasets:
    #     print("GT shape: ", gt.shape)
    #     print("Min. GT label: ", gt.min())

    # if post_proc_config.get('stacking_2D', False):
    #     print('2D stacking...')
    #     stacked_pred_segm = np.empty_like(pred_segm)
    #     max_label = 0
    #     for z in range(pred_segm.shape[0]):
    #         slc = vigra.analysis.labelImage(pred_segm[z].astype(np.uint32))
    #         stacked_pred_segm[z] = slc + max_label
    #         max_label += slc.max() + 1
    #     pred_segm = stacked_pred_segm

    # pred_segm_WS = None
    # if post_proc_config.get('thresh_segm_size', 0) != 0:
    grow = SizeThreshAndGrowWithWS(
        post_proc_config['thresh_segm_size'],
        offsets,
        hmap_kwargs=post_proc_config['prob_map_kwargs'],
        apply_WS_growing=True,
    )
    pred_segm_WS = grow(1 - inverted_affinities, pred_segm)

    # SAVING RESULTS:
    # NOTE (assumption): `gt`, the ground-truth segmentation, is not defined
    # in this snippet; it is assumed to come from the enclosing scope.
    evals = cremi_score(gt,
                        pred_segm,
                        border_threshold=None,
                        return_all_scores=True)
    evals_WS = cremi_score(gt,
                           pred_segm_WS,
                           border_threshold=None,
                           return_all_scores=True)
    print("Scores achieved WS: ", evals_WS)
    print("Scores achieved: ", evals)

    ID = str(np.random.randint(10000000))

    extra_agglo = post_proc_config['generalized_HC_kwargs'][
        'agglomeration_kwargs']['extra_aggl_kwargs']
    agglo_type = extra_agglo['update_rule']
    non_link = extra_agglo['add_cannot_link_constraints']
    # np.asscalar is deprecated (removed in recent NumPy); float() covers
    # both NumPy and plain Python scalars.
    edge_prob = str(
        float(post_proc_config['generalized_HC_kwargs']
              ['probability_long_range_edges']))

    result_file = os.path.join(
        '/home/abailoni_local/',
        'generalized_GED_comparison_local_attraction.json')
    # result_dict = yaml2dict(result_file)
    # result_dict = {} if result_dict is None else result_dict
    if os.path.exists(result_file):
        with open(result_file, 'rb') as f:
            result_dict = json.load(f)
        os.remove(result_file)
    else:
        result_dict = {}

    # NOTE (assumption): `MC_energy`, the multicut energy of the predicted
    # segmentation, is not defined in this snippet; it is assumed to be
    # computed elsewhere in the enclosing scope.
    new_results = {
        agglo_type: {
            str(non_link): {
                edge_prob: {
                    ID: {
                        'energy': float(MC_energy),
                        'score': evals,
                        'score_WS': evals_WS,
                        'runtime': comp_time
                    }
                }
            }
        }
    }

    file_name = "{}_{}_{}".format(ID, agglo_type, edge_prob)
    # file_path = os.path.join('/net/hciserver03/storage/abailoni/GEC_comparison', file_name)
    file_path = os.path.join(
        '/home/abailoni_local/GEC_comparison_local_attraction', file_name)

    result_dict = recursive_dict_update(new_results, result_dict)

    with open(result_file, 'w') as f:
        json.dump(result_dict, f, indent=4, sort_keys=True)
        # yaml.dump(result_dict, f)

    # Save the segmentations to HDF5:
    vigra.writeHDF5(pred_segm.astype('uint32'), file_path, 'segm')
    vigra.writeHDF5(pred_segm_WS.astype('uint32'), file_path, 'segm_WS')
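get_segmentation pops and reads a number of keys from post_proc_config. The hypothetical minimal config below lists only the keys referenced in this snippet; the values are illustrative placeholders, not tested defaults.

# Hypothetical minimal config; keys taken from the snippet above,
# values are placeholders only.
post_proc_config = {
    'nb_threads': 8,
    'invert_affinities': False,
    'segm_pipeline_type': 'gen_HC',
    'use_final_agglomerater': False,
    'blockwise': False,
    'blockwise_kwargs': {},
    'thresh_segm_size': 20,
    'prob_map_kwargs': {},
    'generalized_HC_kwargs': {
        'probability_long_range_edges': 0.1,
        'agglomeration_kwargs': {
            'extra_aggl_kwargs': {
                'update_rule': 'max',
                'add_cannot_link_constraints': False,
            },
        },
        'final_agglomeration_kwargs': {},
    },
}

Note that nb_threads, invert_affinities and segm_pipeline_type are popped first, so only the remaining keys are forwarded to get_segmentation_pipeline via **post_proc_config.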
Example 8

def build_val_loader(self):
    kwargs = recursive_dict_update(self.get('loaders/val'),
                                   deepcopy(self.get('loaders/general')))
    datasets = MultiScaleDataset.from_config(kwargs)
    return DataLoader(datasets, **kwargs.get("loader_config", {}))