Example #1
0
def pipe_gray3d_slic_features_gmm_graphcut(image,
                                           nb_classes=4,
                                           spacing=(12, 1, 1),
                                           sp_size=15,
                                           sp_regul=0.2,
                                           gc_regul=0.1,
                                           dict_features=FTS_SET_SIMPLE):
    """ complete pipe-line for segmentation of a 3D gray-scale image using
    superpixels, extracting superpixel features, a GMM class model
    and graphCut segmentation

    :param ndarray image: input 3D gray-scale image (volume)
    :param int nb_classes: number of classes to be segmented (indexing from 0)
    :param (int, int, int) spacing: passed as ``space`` to the 3D SLIC
        superpixel segmentation (presumably relative spacing per axis - confirm)
    :param int sp_size: initial size of a superpixel (meaning edge length)
    :param float sp_regul: regularisation in range (0;1) where "0" gives elastic
           and "1" nearly square segments
    :param float gc_regul: regularisation for GC
    :param dict_features: selection of features, passed to
        ``compute_selected_features_gray3d`` to choose what is computed
        for each superpixel
    :return ndarray: segmentation mapping each voxel into a class,
        same shape as the superpixel label volume (i.e. as ``image``)

    >>> np.random.seed(0)
    >>> image = np.random.random((5, 125, 150)) / 2.
    >>> image[:, :, :75] += 0.5
    >>> segm = pipe_gray3d_slic_features_gmm_graphcut(image)
    >>> segm.shape
    (5, 125, 150)
    """
    logging.info('PIPELINE Superpixels-Features-GraphCut')
    slic = seg_sp.segment_slic_img3d_gray(image,
                                          sp_size=sp_size,
                                          rltv_compact=sp_regul,
                                          space=spacing)

    logging.info('extract segments/superpixels features.')
    features, _ = seg_fts.compute_selected_features_gray3d(
        image, slic, dict_features)
    logging.debug('list of features RAW: %s', repr(features.shape))
    # NaN features would break the normalisation and the GMM fit
    features[np.isnan(features)] = 0

    logging.info('norm all features.')
    features, _ = seg_fts.norm_features(features)
    logging.debug('list of features NORM: %s', repr(features.shape))

    # unsupervised class model -> per-superpixel class probabilities (unary term)
    model = seg_gc.estim_class_model_gmm(features, nb_classes)
    proba = model.predict_proba(features)
    logging.debug('list of probabilities: %s', repr(proba.shape))

    graph_labels = seg_gc.segment_graph_cut_general(slic, proba, image,
                                                    features, gc_regul)

    # graph labels are per superpixel; index by the superpixel label volume
    # to obtain a per-voxel segmentation
    return graph_labels[slic]
    def test_count_transitions_segment(self):
        """ build a pairwise GC cost from label transitions between connected
        superpixels and use it as regularisation in the general graph cut """
        img = self.img[:, :, 0]
        annot = self.annot.astype(int)

        slic = segment_slic_img2d(img, sp_size=15, relative_compact=0.2)
        label_hist = histogram_regions_labels_norm(slic, annot)
        labels = np.argmax(label_hist, axis=1)
        trans = count_label_transitions_connected_segments({'a': slic}, {'a': labels})
        path_csv = os.path.join(PATH_OUTPUT, 'labels_transitions.csv')
        pd.DataFrame(trans).to_csv(path_csv)
        gc_regul = compute_pairwise_cost_from_transitions(trans, 10.)

        np.random.seed(0)
        features = np.tile(labels, (5, 1)).T.astype(float)
        features += np.random.random(features.shape) - 0.5

        gmm = estim_class_model_gmm(features, 4)
        proba = gmm.predict_proba(features)

        # the third positional parameter of segment_graph_cut_general is the
        # image, so the pairwise cost matrix must be passed by keyword to be
        # actually used as the GC regularisation
        graph_labels = segment_graph_cut_general(slic, proba, gc_regul=gc_regul)
        # labels are per superpixel; mapping back gives one label per pixel
        segm = graph_labels[slic]
        assert segm.shape == slic.shape