# Example #1
# 0
def _diameter_bound(XYZ, idx, diam):
    """Return the diameter of voxel set `idx`, or `diam + 1` if it is too big.

    When `idx` holds more than `(diam + 1)**3` voxels, `max_dist` is not
    evaluated and `diam + 1` is returned. This is a cheap shortcut that
    signals "constraint violated" to the caller; presumably such a set cannot
    fit within `diam` on a voxel grid.
    """
    if len(idx) <= (diam + 1) ** 3:
        return max_dist(XYZ, idx, idx)
    return diam + 1


def _grow_cluster(T, XYZ, seed, candidates, diam, seed_diameter):
    """Greedily extend `seed` with `candidates`, taken by decreasing T value.

    Candidates are added one at a time while the cluster diameter stays
    <= `diam`. Growth stops at the first candidate that would break the
    constraint, or when no candidates remain.
    Returns the index array of the grown cluster.
    """
    cluster = seed
    diameter = seed_diameter
    for j in candidates[np.argsort(T[candidates])[::-1]]:
        voxel = np.array([j], int)
        new_diameter = max(diameter, max_dist(XYZ, cluster, voxel))
        if new_diameter > diam:
            break
        cluster = np.concatenate((cluster, voxel))
        diameter = new_diameter
    return cluster


def extract_clusters_from_diam(T,XYZ,th,diam,k=18):
    """
    Extract clusters from a statistical map
    under diameter constraint
    and above given threshold

    Connected components above `th` whose diameter is at most `diam` are
    kept whole. A larger component is handled through the bifurcation tree
    of its field. If the voxels at or above the root height meet the
    constraint, they seed one cluster, which is grown greedily by decreasing
    T. Otherwise the procedure recurses on the voxels strictly above the
    root height.

    In:  T      (p)     statistical map
         XYZ    (3,p)   voxels coordinates
         th     <float> minimum threshold
         diam   <int>   maximal diameter (in voxels)
         k      <int>   the number of neighbours considered. (6,18 or 26)
    Out: labels (p)     cluster labels, -1 for voxels in no cluster
    """
    if np.size(T) == 0:
        # Nothing to cluster (e.g. an empty recursive sub-region);
        # `.max()` below would raise on an empty label array.
        return -np.ones(0, int)
    CClabels = extract_clusters_from_thresh(T,XYZ,th,k)
    labels = -np.ones(len(CClabels),int)
    clust_label = 0
    for i in range(CClabels.max() + 1):
        I = np.where(CClabels==i)[0]
        if _diameter_bound(XYZ, I, diam) <= diam:
            # The whole connected component already meets the constraint.
            labels[I] = clust_label
            clust_label += 1
            continue

        # Build the field on this component and compute its blobs.
        p = len(I)
        F = Field(p)
        F.from_3d_grid(np.transpose(XYZ[:,I]),k)
        F.set_field(np.reshape(T[I],(p,1)))
        idx, height, parent, label = F.threshold_bifurcations(0,th)
        # Root node(s): nodes that are their own parent.
        root = np.where(np.arange(np.size(idx))==parent)[0]
        # NOTE(review): assumes a single root per connected component;
        # several roots would make the comparisons below ill-defined.
        root_height = height[root]

        # Can the constraint be met within the current region?
        Imin = I[T[I] >= root_height]
        dmin = _diameter_bound(XYZ, Imin, diam)
        if dmin <= diam:
            # Yes: grow the largest cluster around Imin meeting the constraint.
            Iclust = _grow_cluster(T, XYZ, Imin, I[T[I] < root_height],
                                   diam, dmin)
            labels[Iclust] = clust_label
            clust_label += 1
        else:
            # No: search inside the sub-regions strictly above the root height.
            Irest = I[T[I] > root_height]
            rest_labels = extract_clusters_from_diam(T[Irest],XYZ[:,Irest],th,diam,k)
            found = rest_labels >= 0
            # Only advance the running label if clusters were found; otherwise
            # `rest_labels.max() + 1` would be 0 and reset the numbering.
            if found.any():
                rest_labels[found] += clust_label
                clust_label = rest_labels.max() + 1
            labels[Irest] = rest_labels
    return labels
"""
print(__doc__)

import numpy as np
import os
from nipy.io.imageformats import load, save, Nifti1Image 
from nipy.neurospin.graph.field import Field
import get_data_light
import tempfile

# Directory holding the example data (provided by get_data_light)
data_dir = get_data_light.get_it()

# paths
swd = tempfile.mkdtemp()  # temporary output directory (not removed by this script)
input_image = os.path.join(data_dir, 'spmT_0029.nii.gz')
mask_image = os.path.join(data_dir, 'mask.nii.gz')

# Argument passed to Field.ward (presumably the number of clusters)
n_clusters = 100

# Load the mask image once: it supplies both the voxel set and the output affine.
mask_img = load(mask_image)
mask = mask_img.get_data() > 0
ijk = np.array(np.where(mask)).T
nvox = ijk.shape[0]
data = load(input_image).get_data()[mask]

# 6-neighbour grid graph over in-mask voxels, carrying the map values
image_field = Field(nvox)
image_field.from_3d_grid(ijk, k=6)
image_field.set_field(data)
u = image_field.ward(n_clusters)

# Write labels back into a volume: -1 outside the mask, cluster label inside.
label_image = os.path.join(swd, 'label.nii')
wdata = -np.ones(mask.shape, int)
wdata[mask] = u
save(Nifti1Image(wdata, mask_img.get_affine()), label_image)
print("Label image written in %s" % label_image)