def segment_rg2sp_graphcut(slic,
                           seg,
                           centers,
                           labels_fg_prob,
                           path_model,
                           coef_shape,
                           coef_pairwise=5,
                           allow_obj_swap=True,
                           prob_label_trans=(0.1, 0.03),
                           dict_thresholds=RG2SP_THRESHOLDS,
                           debug_export=''):
    """ wrapper for region growing method with some debug exporting

    The shape model is read from `path_model`, then
    `seg_rg.region_growing_shape_slic_graphcut` is called with
    `optim_global=True` and `nb_iter=250`. The per-superpixel labels it
    returns are mapped back through `slic` to get the final segmentation.

    :param slic: integer array of superpixel indexes; it is used to index
        the label vector returned by the region growing
    :param seg: input segmentation, forwarded to the region growing
    :param centers: initial object centres, forwarded to the region growing
    :param labels_fg_prob: forwarded unchanged to the region growing
    :param str path_model: path to the shape model; a file with the `.npz`
        extension is read with `np.load`, anything else is unpickled.
        The loaded model has to provide the keys 'mix_model', 'cdfs'
        and 'name'.
    :param coef_shape: forwarded unchanged to the region growing
    :param coef_pairwise: forwarded unchanged to the region growing
    :param bool allow_obj_swap: forwarded unchanged to the region growing
    :param prob_label_trans: forwarded unchanged to the region growing
    :param dict_thresholds: forwarded unchanged to the region growing
    :param str debug_export: if this is an existing directory, the debug
        history is collected and one figure per recorded iteration is saved
        there as `iter_NNN`; otherwise no debug history is requested
    :return: tuple `(segm_obj, centers, None)` where `segm_obj` is
        `labels_gc[slic]` and `centers` is the argument as passed in
    """
    if os.path.splitext(path_model)[-1] == '.npz':
        # allow_pickle=True lets object-typed entries load on NumPy versions
        # where pickles are disabled by default (presumably the mixture model
        # is such an entry - confirm against the code writing the file).
        # The context manager closes the archive once the entries are read.
        with np.load(path_model, allow_pickle=True) as npz_model:
            shape_model = {key: npz_model[key]
                           for key in ('mix_model', 'cdfs', 'name')}
    else:
        # NOTE: unpickling executes arbitrary code, `path_model` has to come
        # from a trusted source
        with open(path_model, 'rb') as fp:
            shape_model = pickle.load(fp)

    # the debug history is collected only if there is a folder to export it
    dict_debug = dict() if os.path.isdir(debug_export) else None

    labels_gc = seg_rg.region_growing_shape_slic_graphcut(
        seg,
        slic,
        centers, (shape_model['mix_model'], shape_model['cdfs']),
        shape_model['name'],
        labels_fg_prob,
        coef_shape,
        coef_pairwise,
        prob_label_trans,
        optim_global=True,
        allow_obj_swap=allow_obj_swap,
        dict_thresholds=dict_thresholds,
        nb_iter=250,
        dict_debug_history=dict_debug)

    if dict_debug is not None:
        # one figure per entry recorded under the 'energy' key
        nb_iter = len(dict_debug['energy'])
        for i in range(nb_iter):
            fig = tl_visu.figure_rg2sp_debug_complete(seg, slic, dict_debug, i)
            fig.savefig(os.path.join(debug_export, 'iter_%03d' % i))
            # release the figure, otherwise they accumulate over iterations
            plt.close(fig)

    # map the per-superpixel labels back through the superpixel index array
    segm_obj = labels_gc[slic]
    return segm_obj, centers, None
Esempio n. 2
0
    def test_region_growing_graphcut(self, name='insitu7545'):
        """ run the graph-cut region growing on a single sample and compare
        the result with its annotation

        The shape model is loaded from `PATH_PKL_MODEL`; if the file does not
        exist, `test_shape_modeling` is called first (presumably it creates
        the file - confirm). A few debug figures are exported to
        `PATH_OUTPUT` and the test requires an adjusted Rand score of at
        least 0.5 between the annotation and the segmentation.

        :param str name: base name (without extension) shared by the image,
            segmentation, annotation and centre files of the sample
        """
        if not os.path.exists(PATH_PKL_MODEL):
            self.test_shape_modeling()

        # The first entry of the archive is converted to a dict, so it is
        # presumably stored as a pickled object; NumPy >= 1.16.3 refuses to
        # load such entries unless `allow_pickle=True`. The context manager
        # closes the archive after reading.
        with np.load(PATH_PKL_MODEL, allow_pickle=True) as npz_file:
            file_model = dict(npz_file[npz_file.files[0]].tolist())
        logging.info('loaded model: %s', repr(file_model.keys()))
        list_mean_cdf = file_model['cdfs']
        model = file_model['mix_model']

        img, _ = tl_data.load_image_2d(os.path.join(PATH_IMAGE, name + '.jpg'))
        seg, _ = tl_data.load_image_2d(os.path.join(PATH_SEGM, name + '.png'))
        annot, _ = tl_data.load_image_2d(
            os.path.join(PATH_ANNOT, name + '.png'))
        centers = pd.read_csv(os.path.join(PATH_CENTRE, name + '.csv'),
                              index_col=0).values
        # swap the first two columns of the loaded centres
        centers[:, [0, 1]] = centers[:, [1, 0]]

        slic = seg_spx.segment_slic_img2d(img, sp_size=25, rltv_compact=0.3)

        dict_debug = {}
        labels_gc = seg_rg.region_growing_shape_slic_graphcut(
            seg,
            slic,
            centers, (model, list_mean_cdf),
            'set_cdfs',
            LABELS_FG_PROB,
            coef_shape=5.,
            coef_pairwise=15.,
            prob_label_trans=[0.1, 0.03],
            optim_global=False,
            nb_iter=65,
            allow_obj_swap=False,
            dict_thresholds=DEFAULT_RG2SP_THRESHOLDS,
            dict_debug_history=dict_debug)

        # map the per-superpixel labels back through the superpixel indexes
        segm_obj = labels_gc[slic]
        logging.info('debug: %s', repr(dict_debug.keys()))

        # export five evenly spaced snapshots of the recorded history
        for i in np.linspace(0, len(dict_debug['energy']) - 1, 5):
            # linspace yields floats, the iteration index has to be integer
            idx = int(i)
            fig = tl_visu.figure_rg2sp_debug_complete(seg,
                                                      slic,
                                                      dict_debug,
                                                      idx,
                                                      max_size=5)
            fig_name = 'RG2Sp_graph-cut_%s_debug-%03d.pdf' % (name, idx)
            fig.savefig(os.path.join(PATH_OUTPUT, fig_name),
                        bbox_inches='tight',
                        pad_inches=0)
            plt.close(fig)

        score = adjusted_rand_score(annot.ravel(), segm_obj.ravel())
        self.assertGreaterEqual(score, 0.5)

        expert_segm(name,
                    img,
                    seg,
                    segm_obj,
                    annot,
                    str_type='RG2Sp_graph-cut')