コード例 #1
0
ファイル: fiberfilter.py プロジェクト: moloney/cmp
def compute_length_array(trkfile=None, streams=None, savefname = 'lengths.npy'):
    """Compute the length of every fiber and store the result as a .npy file.

    Either ``streams`` (already loaded streamlines) or ``trkfile`` (path of a
    track file to read) must be supplied.  When both are given, ``streams``
    takes precedence and ``trkfile`` is ignored.

    Parameters
    ----------
    trkfile : str, optional
        Track file that is read with ``tv.read`` when ``streams`` is None.
    streams : sequence, optional
        Streamlines; the first element of each item is passed to
        ``util.length``.  Must support ``len()``.
    savefname : str
        File name (inside ``gconf.get_cmp_fibers()``) the array is saved to.

    Returns
    -------
    numpy.ndarray
        One float per fiber, in the order the fibers were iterated.

    Raises
    ------
    ValueError
        If neither ``trkfile`` nor ``streams`` is given.
    Exception
        If the header of ``trkfile`` reports ``n_count == 0``.
    """
    if streams is None:
        if trkfile is None:
            # Previously this fell through to len(None) and died with an
            # unhelpful TypeError.
            raise ValueError("Either trkfile or streams must be given.")
        log.info("Compute length array for fibers in %s" % trkfile)
        # Read lazily; the fiber count is taken from the header instead of
        # materialising all streamlines.
        streams, hdr = tv.read(trkfile, as_generator = True)
        n_fibers = hdr['n_count']
        if n_fibers == 0:
            msg = "Header field n_count of trackfile %s is set to 0. No track seem to exist in this file." % trkfile
            log.error(msg)
            raise Exception(msg)
    else:
        n_fibers = len(streams)

    # np.float was only an alias of the builtin float (float64) and has been
    # removed from NumPy; spell the dtype out explicitly.
    leng = np.zeros(n_fibers, dtype = np.float64)
    for i, fib in enumerate(streams):
        leng[i] = util.length(fib[0])

    # store length array
    lefname = op.join(gconf.get_cmp_fibers(), savefname)
    np.save(lefname, leng)
    log.info("Store lengths array to: %s" % lefname)

    return leng
コード例 #2
0
ファイル: creatematrix.py プロジェクト: danginsburg/cmp
def cmat():
    """Create the connection matrix for each resolution using fibers and ROIs.

    Reads ``streamline_filtered.trk`` from the fibers directory, converts the
    fiber endpoints with ``create_endpoints_array`` and then, for every
    resolution key in ``gconf.parcellation``, builds an ``nx.Graph`` whose
    nodes come from the parcellation GraphML file and whose edges connect the
    ROI labels found at the two endpoints of each fiber.

    Files written once:

    * ``endpoints.npy`` / ``endpointsmm.npy``
    * ``meancurvature.npy`` (only if ``gconf.compute_curvature`` is set)

    Files written per resolution ``r``:

    * ``connectome_<r>.gpickle`` -- graph; each edge carries
      ``number_of_fibers``, ``fiber_length_mean`` and ``fiber_length_std``
    * ``final_fiberslength_<r>.npy`` -- lengths of the retained fibers
    * ``filtered_fiberslabel_<r>.npy`` -- (start, end) label for every fiber;
      orphans have ``-1`` in the first column
    * ``final_fiberlabels_<r>.npy`` -- labels of the retained fibers only
    * ``streamline_final_<r>.trk`` -- the retained fibers
    """

    # create the endpoints for each fibers
    en_fname = op.join(gconf.get_cmp_fibers(), 'endpoints.npy')
    en_fnamemm = op.join(gconf.get_cmp_fibers(), 'endpointsmm.npy')
    curv_fname = op.join(gconf.get_cmp_fibers(), 'meancurvature.npy')
    intrk = op.join(gconf.get_cmp_fibers(), 'streamline_filtered.trk')

    fib, hdr = nibabel.trackvis.read(intrk, False)

    # Previously, load_endpoints_from_trk() used the voxel size stored
    # in the track hdr to transform the endpoints to ROI voxel space.
    # This only works if the ROI voxel size is the same as the DSI/DTI
    # voxel size.  In the case of DTI, it is not.
    # We do, however, assume that all of the ROI images have the same
    # voxel size, so this code just loads the first one to determine
    # what it should be
    # (list() keeps the lookup working where keys() is not indexable.)
    firstROIFile = op.join(gconf.get_cmp_tracto_mask_tob0(),
                           list(gconf.parcellation.keys())[0],
                           'ROI_HR_th.nii.gz')
    firstROI = nibabel.load(firstROIFile)
    roiVoxelSize = firstROI.get_header().get_zooms()
    (endpoints, endpointsmm) = create_endpoints_array(fib, roiVoxelSize)
    np.save(en_fname, endpoints)
    np.save(en_fnamemm, endpointsmm)

    # only compute curvature if required
    if gconf.compute_curvature:
        meancurv = compute_curvature_array(fib)
        np.save(curv_fname, meancurv)

    log.info("========================")

    n = len(fib)

    resolution = gconf.parcellation.keys()

    for r in resolution:

        log.info("Resolution = " + r)

        # create empty fiber label array
        fiberlabels = np.zeros((n, 2))
        final_fiberlabels = []
        final_fibers_idx = []

        # Open the corresponding ROI
        log.info("Open the corresponding ROI")
        roi_fname = op.join(gconf.get_cmp_tracto_mask_tob0(), r,
                            'ROI_HR_th.nii.gz')
        roi = nibabel.load(roi_fname)
        roiData = roi.get_data()

        # Create the matrix
        nROIs = gconf.parcellation[r]['number_of_regions']
        log.info("Create the connection matrix (%s rois)" % nROIs)
        G = nx.Graph()

        # add node information from parcellation
        gp = nx.read_graphml(gconf.parcellation[r]['node_information_graphml'])
        for u, d in gp.nodes_iter(data=True):
            G.add_node(int(u), d)

        # number of orphan fibers (an endpoint lies in an unlabeled voxel)
        dis = 0

        log.info("Create the connection matrix")
        for i in range(endpoints.shape[0]):

            # ROI start => ROI end
            try:
                startROI = int(roiData[endpoints[i, 0, 0], endpoints[i, 0, 1],
                                       endpoints[i, 0, 2]])
                endROI = int(roiData[endpoints[i, 1, 0], endpoints[i, 1, 1],
                                     endpoints[i, 1, 2]])
            except IndexError:
                log.error(
                    "AN INDEXERROR EXCEPTION OCCURED FOR FIBER %s. PLEASE CHECK ENDPOINT GENERATION"
                    % i)
                continue

            # Filter
            if startROI == 0 or endROI == 0:
                dis += 1
                fiberlabels[i, 0] = -1
                continue

            if startROI > nROIs or endROI > nROIs:
                log.debug(
                    "Start or endpoint of fiber terminate in a voxel which is labeled higher"
                )
                log.debug(
                    "than is expected by the parcellation node information.")
                log.debug("Start ROI: %i, End ROI: %i" % (startROI, endROI))
                log.debug("This needs bugfixing!")
                continue

            # Update fiber label
            # switch the rois in order to enforce startROI < endROI
            if endROI < startROI:
                startROI, endROI = endROI, startROI

            fiberlabels[i, 0] = startROI
            fiberlabels[i, 1] = endROI

            final_fiberlabels.append([startROI, endROI])
            final_fibers_idx.append(i)

            # Add edge to graph
            if G.has_edge(startROI, endROI):
                G.edge[startROI][endROI]['fiblist'].append(i)
            else:
                G.add_edge(startROI, endROI, fiblist=[i])

        # Only fibers that passed every filter are valid; fibers skipped for
        # out-of-volume endpoints or too-high labels are not orphans but are
        # not valid either, so "n - dis" over-reported.  Guard the percentages
        # against an empty track file.
        n_valid = len(final_fibers_idx)
        pct_orphans = dis * 100.0 / n if n else 0.0
        pct_valid = n_valid * 100.0 / n if n else 0.0
        log.info(
            "Found %i (%f percent out of %i fibers) fibers that start or terminate in a voxel which is not labeled. (orphans)"
            % (dis, pct_orphans, n))
        log.info("Valid fibers: %i (%f percent)" % (n_valid, pct_valid))

        # create a final fiber length array
        final_fiberlength_array = np.array(
            [length(fib[idx][0]) for idx in final_fibers_idx])

        # make final fiber labels as array
        final_fiberlabels_array = np.array(final_fiberlabels, dtype=np.int32)

        # update edges
        # measures to add here
        # Iterate over a snapshot: the loop removes and re-adds every edge,
        # which must not happen while the live edge iterator is being consumed.
        for u, v, d in list(G.edges_iter(data=True)):
            G.remove_edge(u, v)
            di = {
                'number_of_fibers': len(d['fiblist']),
            }

            # additional measures
            # compute mean/std of fiber measure
            # Labels were stored as (smaller, larger); the graph may hand the
            # edge back in either orientation.
            lo, hi = min(int(u), int(v)), max(int(u), int(v))
            idx = np.where((final_fiberlabels_array[:, 0] == lo)
                           & (final_fiberlabels_array[:, 1] == hi))[0]

            di['fiber_length_mean'] = np.mean(final_fiberlength_array[idx])
            di['fiber_length_std'] = np.std(final_fiberlength_array[idx])

            G.add_edge(u, v, di)

        # storing network
        nx.write_gpickle(
            G, op.join(gconf.get_cmp_matrices(), 'connectome_%s.gpickle' % r))

        log.info("Storing final fiber length array")
        fiberlabels_fname = op.join(gconf.get_cmp_fibers(),
                                    'final_fiberslength_%s.npy' % str(r))
        np.save(fiberlabels_fname, final_fiberlength_array)

        log.info("Storing all fiber labels (with orphans)")
        fiberlabels_fname = op.join(gconf.get_cmp_fibers(),
                                    'filtered_fiberslabel_%s.npy' % str(r))
        np.save(
            fiberlabels_fname,
            np.array(fiberlabels, dtype=np.int32),
        )

        log.info("Storing final fiber labels (no orphans)")
        fiberlabels_noorphans_fname = op.join(
            gconf.get_cmp_fibers(), 'final_fiberlabels_%s.npy' % str(r))
        np.save(fiberlabels_noorphans_fname, final_fiberlabels_array)

        log.info("Filtering tractography - keeping only no orphan fibers")
        finalfibers_fname = op.join(gconf.get_cmp_fibers(),
                                    'streamline_final_%s.trk' % str(r))
        save_fibers(hdr, fib, finalfibers_fname, final_fibers_idx)

    log.info("Done.")
    log.info("========================")
コード例 #3
0
ファイル: creatematrix.py プロジェクト: rudolphpienaar/cmp
def cmat():
    """Create the connection matrix for each resolution using fibers and ROIs.

    Reads ``streamline_filtered.trk`` from the fibers directory, converts the
    fiber endpoints with ``create_endpoints_array`` and then, for every
    resolution key in ``gconf.parcellation``, builds an ``nx.Graph`` whose
    nodes come from the parcellation GraphML file and whose edges connect the
    ROI labels found at the two endpoints of each fiber.

    Node attribute added here: ``dn_position`` -- string form of the mean
    voxel index of the ROI whose label equals ``dn_correspondence_id``.

    Edge attributes: ``number_of_fibers``, ``fiber_length_mean``,
    ``fiber_length_std`` and, for every scalar measure enabled in ``gconf``
    for the configured ``gconf.diffusion_imaging_model`` (DSI, DTI or QBALL),
    ``<measure>_mean`` / ``<measure>_std`` of the scalar volume sampled along
    the fibers of that edge.  If no fiber of an edge can be sampled the two
    values are NaN.  An unknown imaging model yields no scalar measures.

    Files written once:

    * ``endpoints.npy`` / ``endpointsmm.npy``
    * ``meancurvature.npy`` (only if ``gconf.compute_curvature`` is set)

    Files written per resolution ``r``:

    * ``connectome_<r>.gpickle``
    * ``final_fiberslength_<r>.npy`` -- lengths of the retained fibers
    * ``filtered_fiberslabel_<r>.npy`` -- (start, end) label for every fiber;
      orphans have ``-1`` in the first column
    * ``final_fiberlabels_<r>.npy`` -- labels of the retained fibers only
    * ``streamline_final_<r>.trk`` -- the retained fibers
    """

    # create the endpoints for each fibers
    en_fname = op.join(gconf.get_cmp_fibers(), "endpoints.npy")
    en_fnamemm = op.join(gconf.get_cmp_fibers(), "endpointsmm.npy")
    curv_fname = op.join(gconf.get_cmp_fibers(), "meancurvature.npy")
    intrk = op.join(gconf.get_cmp_fibers(), "streamline_filtered.trk")

    fib, hdr = nibabel.trackvis.read(intrk, False)

    # Previously, load_endpoints_from_trk() used the voxel size stored
    # in the track hdr to transform the endpoints to ROI voxel space.
    # This only works if the ROI voxel size is the same as the DSI/DTI
    # voxel size.  In the case of DTI, it is not.
    # We do, however, assume that all of the ROI images have the same
    # voxel size, so this code just loads the first one to determine
    # what it should be
    # (list() keeps the lookup working where keys() is not indexable.)
    firstROIFile = op.join(
        gconf.get_cmp_tracto_mask_tob0(), list(gconf.parcellation.keys())[0], "ROI_HR_th.nii.gz"
    )
    firstROI = nibabel.load(firstROIFile)
    roiVoxelSize = firstROI.get_header().get_zooms()
    (endpoints, endpointsmm) = create_endpoints_array(fib, roiVoxelSize)
    np.save(en_fname, endpoints)
    np.save(en_fnamemm, endpointsmm)

    # only compute curvature if required
    if gconf.compute_curvature:
        meancurv = compute_curvature_array(fib)
        np.save(curv_fname, meancurv)

    log.info("========================")

    n = len(fib)

    # prepare: load the scalar volumes the edge measures are computed from.
    # Each entry is (measure name, gconf flag enabling it, volume file name).
    # The volumes do not depend on the resolution, so they are loaded once
    # instead of once per resolution.  The gconf flags are only read for the
    # configured model.
    measure_files = {
        "DSI": (
            ("P0", "connection_P0", "dsi_P0.nii.gz"),
            ("gfa", "connection_gfa", "dsi_gfa.nii.gz"),
            ("kurtosis", "connection_kurtosis", "dsi_kurtosis.nii.gz"),
            ("skewness", "connection_skewness", "dsi_skewness.nii.gz"),
        ),
        "DTI": (
            ("adc", "connection_adc", "dti_adc.nii.gz"),
            ("fa", "connection_fa", "dti_fa.nii.gz"),
        ),
        "QBALL": (
            ("P0", "connection_P0", "hardi_P0.nii.gz"),
            ("gfa", "connection_gfa", "hardi_gfa.nii.gz"),
            ("kurtosis", "connection_kurtosis", "hardi_kurtosis.nii.gz"),
            ("skewness", "connection_skewness", "hardi_skewness.nii.gz"),
        ),
    }
    # measure name -> (volume data, voxel sizes).  Stays empty for an unknown
    # imaging model, which previously ended in a NameError further down.
    mmapdata = {}
    for k, flag, fname in measure_files.get(gconf.diffusion_imaging_model, ()):
        if not getattr(gconf, flag):
            continue
        log.info("Read volume %s" % fname)
        da = nibabel.load(op.join(gconf.get_cmp_scalars(), fname))
        mmapdata[k] = (da.get_data(), da.get_header().get_zooms())

    resolution = gconf.parcellation.keys()

    for r in resolution:

        log.info("Resolution = " + r)

        # create empty fiber label array
        fiberlabels = np.zeros((n, 2))
        final_fiberlabels = []
        final_fibers_idx = []

        # Open the corresponding ROI
        log.info("Open the corresponding ROI")
        roi_fname = op.join(gconf.get_cmp_tracto_mask_tob0(), r, "ROI_HR_th.nii.gz")
        roi = nibabel.load(roi_fname)
        roiData = roi.get_data()

        # Create the matrix
        nROIs = gconf.parcellation[r]["number_of_regions"]
        log.info("Create the connection matrix (%s rois)" % nROIs)
        G = nx.Graph()

        # add node information from parcellation
        gp = nx.read_graphml(gconf.parcellation[r]["node_information_graphml"])
        for u, d in gp.nodes_iter(data=True):
            G.add_node(int(u), d)
            # compute a position for the node based on the mean position of the
            # ROI in voxel coordinates (segmentation volume )
            G.node[int(u)]["dn_position"] = str(
                tuple(np.mean(np.where(roiData == int(d["dn_correspondence_id"])), axis=1))
            )

        # number of orphan fibers (an endpoint lies in an unlabeled voxel)
        dis = 0

        log.info("Create the connection matrix")
        pc = -1
        for i in range(n):

            # Percent counter: log every time the rounded percentage advances
            pcN = int(round(float(100 * i) / n))
            if pcN > pc:
                pc = pcN
                log.info("%4.0f%%" % (pc))

            # ROI start => ROI end
            try:
                startROI = int(roiData[endpoints[i, 0, 0], endpoints[i, 0, 1], endpoints[i, 0, 2]])
                endROI = int(roiData[endpoints[i, 1, 0], endpoints[i, 1, 1], endpoints[i, 1, 2]])
            except IndexError:
                log.info(
                    "An index error occured for fiber %s. This means that the fiber start or endpoint is outside the volume. Continue."
                    % i
                )
                continue

            # Filter
            if startROI == 0 or endROI == 0:
                dis += 1
                fiberlabels[i, 0] = -1
                continue

            if startROI > nROIs or endROI > nROIs:
                log.debug("Start or endpoint of fiber terminate in a voxel which is labeled higher")
                log.debug("than is expected by the parcellation node information.")
                log.debug("Start ROI: %i, End ROI: %i" % (startROI, endROI))
                log.debug("This needs bugfixing!")
                continue

            # Update fiber label
            # switch the rois in order to enforce startROI < endROI
            if endROI < startROI:
                startROI, endROI = endROI, startROI

            fiberlabels[i, 0] = startROI
            fiberlabels[i, 1] = endROI

            final_fiberlabels.append([startROI, endROI])
            final_fibers_idx.append(i)

            # Add edge to graph
            if G.has_edge(startROI, endROI):
                G.edge[startROI][endROI]["fiblist"].append(i)
            else:
                G.add_edge(startROI, endROI, fiblist=[i])

        # Only fibers that passed every filter are valid; fibers skipped for
        # out-of-volume endpoints or too-high labels are not orphans but are
        # not valid either, so "n - dis" over-reported.  Guard the percentages
        # against an empty track file.
        n_valid = len(final_fibers_idx)
        pct_orphans = dis * 100.0 / n if n else 0.0
        pct_valid = n_valid * 100.0 / n if n else 0.0
        log.info(
            "Found %i (%f percent out of %i fibers) fibers that start or terminate in a voxel which is not labeled. (orphans)"
            % (dis, pct_orphans, n)
        )
        log.info("Valid fibers: %i (%f percent)" % (n_valid, pct_valid))

        # create a final fiber length array
        final_fiberlength_array = np.array([length(fib[idx][0]) for idx in final_fibers_idx])

        # make final fiber labels as array
        final_fiberlabels_array = np.array(final_fiberlabels, dtype=np.int32)

        # update edges
        # measures to add here
        # Iterate over a snapshot: the loop removes and re-adds every edge,
        # which must not happen while the live edge iterator is being consumed.
        for u, v, d in list(G.edges_iter(data=True)):
            G.remove_edge(u, v)
            di = {"number_of_fibers": len(d["fiblist"])}

            # Labels were stored as (smaller, larger); the graph may hand the
            # edge back in either orientation.
            lo, hi = min(int(u), int(v)), max(int(u), int(v))

            idx = np.where((final_fiberlabels_array[:, 0] == lo) & (final_fiberlabels_array[:, 1] == hi))[0]
            di["fiber_length_mean"] = float(np.mean(final_fiberlength_array[idx]))
            di["fiber_length_std"] = float(np.std(final_fiberlength_array[idx]))

            # this is indexed into the fibers that are valid in the sense of touching start
            # and end roi and not going out of the volume
            idx_valid = np.where((fiberlabels[:, 0] == lo) & (fiberlabels[:, 1] == hi))[0]
            for k, (vol, zooms) in mmapdata.items():
                val = []
                for fiber_idx in idx_valid:
                    # retrieve indices: fiber points divided by the voxel size
                    # of the scalar volume, truncated to unsigned integers
                    try:
                        idx2 = (fib[fiber_idx][0] / zooms).astype(np.uint32)
                        val.append(vol[idx2[:, 0], idx2[:, 1], idx2[:, 2]])
                    except IndexError as e:
                        log.info("Index error occured when trying extract scalar values for measure %s" % k)
                        log.info("--> Discard fiber with index %s Exception: %s" % (fiber_idx, e))

                if val:
                    da = np.concatenate(val)
                    di[k + "_mean"] = float(da.mean())
                    di[k + "_std"] = float(da.std())
                else:
                    # every fiber of this edge was discarded; np.concatenate
                    # would raise on an empty list
                    log.info("No scalar values for measure %s on edge (%s, %s)" % (k, u, v))
                    di[k + "_mean"] = float("nan")
                    di[k + "_std"] = float("nan")

            G.add_edge(u, v, di)

        # storing network
        nx.write_gpickle(G, op.join(gconf.get_cmp_matrices(), "connectome_%s.gpickle" % r))

        log.info("Storing final fiber length array")
        fiberlabels_fname = op.join(gconf.get_cmp_fibers(), "final_fiberslength_%s.npy" % str(r))
        np.save(fiberlabels_fname, final_fiberlength_array)

        log.info("Storing all fiber labels (with orphans)")
        fiberlabels_fname = op.join(gconf.get_cmp_fibers(), "filtered_fiberslabel_%s.npy" % str(r))
        np.save(fiberlabels_fname, np.array(fiberlabels, dtype=np.int32))

        log.info("Storing final fiber labels (no orphans)")
        fiberlabels_noorphans_fname = op.join(gconf.get_cmp_fibers(), "final_fiberlabels_%s.npy" % str(r))
        np.save(fiberlabels_noorphans_fname, final_fiberlabels_array)

        log.info("Filtering tractography - keeping only no orphan fibers")
        finalfibers_fname = op.join(gconf.get_cmp_fibers(), "streamline_final_%s.trk" % str(r))
        save_fibers(hdr, fib, finalfibers_fname, final_fibers_idx)
コード例 #4
0
ファイル: creatematrix.py プロジェクト: moloney/cmp
def cmat():
    """Create the connection matrix for each resolution using fibers and ROIs.

    Reads ``streamline_filtered.trk`` from the fibers directory, converts the
    fiber endpoints with ``create_endpoints_array`` and then, for every
    resolution key in ``gconf.parcellation``, builds an ``nx.Graph`` whose
    nodes come from the parcellation GraphML file and whose edges connect the
    ROI labels (read from ``ROIv_HR_th.nii.gz``) found at the two endpoints
    of each fiber.

    Node attribute added here: ``dn_position`` -- tuple holding the mean
    voxel index of the ROI whose label equals ``dn_correspondence_id``.

    Edge attributes: ``number_of_fibers``, ``fiber_length_mean``,
    ``fiber_length_std`` and, for every scalar measure enabled in ``gconf``
    for the configured ``gconf.diffusion_imaging_model`` (DSI, DTI or QBALL),
    ``<measure>_mean`` / ``<measure>_std`` of the scalar volume sampled along
    the fibers of that edge.  If no fiber of an edge can be sampled the two
    values are NaN.  An unknown imaging model yields no scalar measures.

    Files written once:

    * ``endpoints.npy`` / ``endpointsmm.npy``
    * ``meancurvature.npy`` (only if ``gconf.compute_curvature`` is set)

    Files written per resolution ``r``:

    * ``connectome_<r>.gpickle``
    * ``final_fiberslength_<r>.npy`` -- lengths of the retained fibers
    * ``filtered_fiberslabel_<r>.npy`` -- (start, end) label for every fiber;
      orphans have ``-1`` in the first column
    * ``final_fiberlabels_<r>.npy`` -- labels of the retained fibers only
    * ``streamline_final_<r>.trk`` -- the retained fibers
    """

    # create the endpoints for each fibers
    en_fname = op.join(gconf.get_cmp_fibers(), 'endpoints.npy')
    en_fnamemm = op.join(gconf.get_cmp_fibers(), 'endpointsmm.npy')
    curv_fname = op.join(gconf.get_cmp_fibers(), 'meancurvature.npy')
    intrk = op.join(gconf.get_cmp_fibers(), 'streamline_filtered.trk')

    fib, hdr = nibabel.trackvis.read(intrk, False)

    # Previously, load_endpoints_from_trk() used the voxel size stored
    # in the track hdr to transform the endpoints to ROI voxel space.
    # This only works if the ROI voxel size is the same as the DSI/DTI
    # voxel size.  In the case of DTI, it is not.
    # We do, however, assume that all of the ROI images have the same
    # voxel size, so this code just loads the first one to determine
    # what it should be
    # (list() keeps the lookup working where keys() is not indexable.)
    firstROIFile = op.join(gconf.get_cmp_tracto_mask_tob0(),
                           list(gconf.parcellation.keys())[0],
                           'ROIv_HR_th.nii.gz')
    firstROI = nibabel.load(firstROIFile)
    roiVoxelSize = firstROI.get_header().get_zooms()
    (endpoints, endpointsmm) = create_endpoints_array(fib, roiVoxelSize)
    np.save(en_fname, endpoints)
    np.save(en_fnamemm, endpointsmm)

    # only compute curvature if required
    if gconf.compute_curvature:
        meancurv = compute_curvature_array(fib)
        np.save(curv_fname, meancurv)

    log.info("========================")

    n = len(fib)

    # prepare: load the scalar volumes the edge measures are computed from.
    # Each entry is (measure name, gconf flag enabling it, volume file name).
    # The volumes do not depend on the resolution, so they are loaded once
    # instead of once per resolution.  The gconf flags are only read for the
    # configured model.
    measure_files = {
        'DSI': (
            ('P0', 'connection_P0', 'dsi_P0.nii.gz'),
            ('gfa', 'connection_gfa', 'dsi_gfa.nii.gz'),
            ('kurtosis', 'connection_kurtosis', 'dsi_kurtosis.nii.gz'),
            ('skewness', 'connection_skewness', 'dsi_skewness.nii.gz'),
        ),
        'DTI': (
            ('adc', 'connection_adc', 'dti_adc.nii.gz'),
            ('fa', 'connection_fa', 'dti_fa.nii.gz'),
        ),
        'QBALL': (
            ('P0', 'connection_P0', 'hardi_P0.nii.gz'),
            ('gfa', 'connection_gfa', 'hardi_gfa.nii.gz'),
            ('kurtosis', 'connection_kurtosis', 'hardi_kurtosis.nii.gz'),
            ('skewness', 'connection_skewness', 'hardi_skewness.nii.gz'),
        ),
    }
    # measure name -> (volume data, voxel sizes).  Stays empty for an unknown
    # imaging model, which previously ended in a NameError further down.
    mmapdata = {}
    for k, flag, fname in measure_files.get(gconf.diffusion_imaging_model, ()):
        if not getattr(gconf, flag):
            continue
        log.info("Read volume %s" % fname)
        da = nibabel.load(op.join(gconf.get_cmp_scalars(), fname))
        mmapdata[k] = (da.get_data(), da.get_header().get_zooms())

    resolution = gconf.parcellation.keys()

    for r in resolution:

        log.info("Resolution = " + r)

        # create empty fiber label array
        fiberlabels = np.zeros((n, 2))
        final_fiberlabels = []
        final_fibers_idx = []

        # Open the corresponding ROI
        log.info("Open the corresponding ROI")
        roi_fname = op.join(gconf.get_cmp_tracto_mask_tob0(), r,
                            'ROIv_HR_th.nii.gz')
        roi = nibabel.load(roi_fname)
        roiData = roi.get_data()

        # Create the matrix
        nROIs = gconf.parcellation[r]['number_of_regions']
        log.info("Create the connection matrix (%s rois)" % nROIs)
        G = nx.Graph()

        # add node information from parcellation
        gp = nx.read_graphml(gconf.parcellation[r]['node_information_graphml'])
        for u, d in gp.nodes_iter(data=True):
            G.add_node(int(u), d)
            # compute a position for the node based on the mean position of the
            # ROI in voxel coordinates (segmentation volume )
            G.node[int(u)]['dn_position'] = tuple(
                np.mean(np.where(roiData == int(d["dn_correspondence_id"])),
                        axis=1))

        # number of orphan fibers (an endpoint lies in an unlabeled voxel)
        dis = 0

        log.info("Create the connection matrix")
        pc = -1
        for i in range(n):  # n: number of fibers

            # Percent counter: log every time the rounded percentage advances
            pcN = int(round(float(100 * i) / n))
            if pcN > pc:
                pc = pcN
                log.info('%4.0f%%' % (pc))

            # ROI start => ROI end
            # (endpoints come from create_endpoints_array)
            try:
                startROI = int(roiData[endpoints[i, 0, 0], endpoints[i, 0, 1],
                                       endpoints[i, 0, 2]])
                endROI = int(roiData[endpoints[i, 1, 0], endpoints[i, 1, 1],
                                     endpoints[i, 1, 2]])
            except IndexError:
                log.info(
                    "An index error occured for fiber %s. This means that the fiber start or endpoint is outside the volume. Continue."
                    % i)
                continue

            # Filter
            if startROI == 0 or endROI == 0:
                dis += 1
                fiberlabels[i, 0] = -1
                continue

            if startROI > nROIs or endROI > nROIs:
                log.debug(
                    "Start or endpoint of fiber terminate in a voxel which is labeled higher"
                )
                log.debug(
                    "than is expected by the parcellation node information.")
                log.debug("Start ROI: %i, End ROI: %i" % (startROI, endROI))
                log.debug("This needs bugfixing!")
                continue

            # Update fiber label
            # switch the rois in order to enforce startROI < endROI
            if endROI < startROI:
                startROI, endROI = endROI, startROI

            fiberlabels[i, 0] = startROI
            fiberlabels[i, 1] = endROI

            final_fiberlabels.append([startROI, endROI])
            final_fibers_idx.append(i)

            # Add edge to graph
            if G.has_edge(startROI, endROI):
                G.edge[startROI][endROI]['fiblist'].append(i)
            else:
                G.add_edge(startROI, endROI, fiblist=[i])

        # Only fibers that passed every filter are valid; fibers skipped for
        # out-of-volume endpoints or too-high labels are not orphans but are
        # not valid either, so "n - dis" over-reported.  Guard the percentages
        # against an empty track file.
        n_valid = len(final_fibers_idx)
        pct_orphans = dis * 100.0 / n if n else 0.0
        pct_valid = n_valid * 100.0 / n if n else 0.0
        log.info(
            "Found %i (%f percent out of %i fibers) fibers that start or terminate in a voxel which is not labeled. (orphans)"
            % (dis, pct_orphans, n))
        log.info("Valid fibers: %i (%f percent)" % (n_valid, pct_valid))

        # create a final fiber length array
        final_fiberlength_array = np.array(
            [length(fib[idx][0]) for idx in final_fibers_idx])

        # make final fiber labels as array
        final_fiberlabels_array = np.array(final_fiberlabels, dtype=np.int32)

        # update edges
        # measures to add here
        # Iterate over a snapshot: the loop removes and re-adds every edge,
        # which must not happen while the live edge iterator is being consumed.
        for u, v, d in list(G.edges_iter(data=True)):
            G.remove_edge(u, v)
            di = {
                'number_of_fibers': len(d['fiblist']),
            }

            # Labels were stored as (smaller, larger); the graph may hand the
            # edge back in either orientation.
            lo, hi = min(int(u), int(v)), max(int(u), int(v))

            # additional measures
            # compute mean/std of fiber measure
            idx = np.where((final_fiberlabels_array[:, 0] == lo)
                           & (final_fiberlabels_array[:, 1] == hi))[0]
            di['fiber_length_mean'] = float(
                np.mean(final_fiberlength_array[idx]))
            di['fiber_length_std'] = float(
                np.std(final_fiberlength_array[idx]))

            # this is indexed into the fibers that are valid in the sense of touching start
            # and end roi and not going out of the volume
            idx_valid = np.where((fiberlabels[:, 0] == lo)
                                 & (fiberlabels[:, 1] == hi))[0]
            for k, (vol, zooms) in mmapdata.items():
                val = []
                for fiber_idx in idx_valid:
                    # retrieve indices: fiber points divided by the voxel size
                    # of the scalar volume, truncated to unsigned integers
                    try:
                        idx2 = (fib[fiber_idx][0] / zooms).astype(np.uint32)
                        val.append(vol[idx2[:, 0], idx2[:, 1], idx2[:, 2]])
                    except IndexError as e:
                        log.info(
                            "Index error occured when trying extract scalar values for measure %s"
                            % k)
                        log.info(
                            "--> Discard fiber with index %s Exception: %s"
                            % (fiber_idx, e))

                if val:
                    da = np.concatenate(val)
                    di[k + '_mean'] = float(da.mean())
                    di[k + '_std'] = float(da.std())
                else:
                    # every fiber of this edge was discarded; np.concatenate
                    # would raise on an empty list
                    log.info("No scalar values for measure %s on edge (%s, %s)"
                             % (k, u, v))
                    di[k + '_mean'] = float('nan')
                    di[k + '_std'] = float('nan')

            G.add_edge(u, v, di)

        # storing network
        nx.write_gpickle(
            G, op.join(gconf.get_cmp_matrices(), 'connectome_%s.gpickle' % r))

        log.info("Storing final fiber length array")
        fiberlabels_fname = op.join(gconf.get_cmp_fibers(),
                                    'final_fiberslength_%s.npy' % str(r))
        np.save(fiberlabels_fname, final_fiberlength_array)

        log.info("Storing all fiber labels (with orphans)")
        fiberlabels_fname = op.join(gconf.get_cmp_fibers(),
                                    'filtered_fiberslabel_%s.npy' % str(r))
        np.save(
            fiberlabels_fname,
            np.array(fiberlabels, dtype=np.int32),
        )

        log.info("Storing final fiber labels (no orphans)")
        fiberlabels_noorphans_fname = op.join(
            gconf.get_cmp_fibers(), 'final_fiberlabels_%s.npy' % str(r))
        np.save(fiberlabels_noorphans_fname, final_fiberlabels_array)

        log.info("Filtering tractography - keeping only no orphan fibers")
        finalfibers_fname = op.join(gconf.get_cmp_fibers(),
                                    'streamline_final_%s.trk' % str(r))
        save_fibers(hdr, fib, finalfibers_fname, final_fibers_idx)
コード例 #5
0
ファイル: creatematrix.py プロジェクト: danginsburg/cmp
def cmat():
    """Create the connection matrix for each resolution using fibers and ROIs.

    Reads ``streamline_filtered.trk`` from the fibers directory, converts the
    fiber endpoints with ``create_endpoints_array`` and then, for every
    resolution key in ``gconf.parcellation``, builds an ``nx.Graph`` whose
    nodes come from the parcellation GraphML file and whose edges connect the
    ROI labels found at the two endpoints of each fiber.

    Files written once:

    * ``endpoints.npy`` / ``endpointsmm.npy``
    * ``meancurvature.npy`` (only if ``gconf.compute_curvature`` is set)

    Files written per resolution ``r``:

    * ``connectome_<r>.gpickle`` -- graph; each edge carries
      ``number_of_fibers``, ``fiber_length_mean`` and ``fiber_length_std``
    * ``final_fiberslength_<r>.npy`` -- lengths of the retained fibers
    * ``filtered_fiberslabel_<r>.npy`` -- (start, end) label for every fiber;
      orphans have ``-1`` in the first column
    * ``final_fiberlabels_<r>.npy`` -- labels of the retained fibers only
    * ``streamline_final_<r>.trk`` -- the retained fibers
    """

    # create the endpoints for each fibers
    en_fname  = op.join(gconf.get_cmp_fibers(), 'endpoints.npy')
    en_fnamemm  = op.join(gconf.get_cmp_fibers(), 'endpointsmm.npy')
    curv_fname  = op.join(gconf.get_cmp_fibers(), 'meancurvature.npy')
    intrk = op.join(gconf.get_cmp_fibers(), 'streamline_filtered.trk')

    fib, hdr    = nibabel.trackvis.read(intrk, False)

    # Previously, load_endpoints_from_trk() used the voxel size stored
    # in the track hdr to transform the endpoints to ROI voxel space.
    # This only works if the ROI voxel size is the same as the DSI/DTI
    # voxel size.  In the case of DTI, it is not.
    # We do, however, assume that all of the ROI images have the same
    # voxel size, so this code just loads the first one to determine
    # what it should be
    # (list() keeps the lookup working where keys() is not indexable.)
    firstROIFile = op.join(gconf.get_cmp_tracto_mask_tob0(),
                           list(gconf.parcellation.keys())[0],
                           'ROI_HR_th.nii.gz')
    firstROI = nibabel.load(firstROIFile)
    roiVoxelSize = firstROI.get_header().get_zooms()
    (endpoints,endpointsmm) = create_endpoints_array(fib, roiVoxelSize)
    np.save(en_fname, endpoints)
    np.save(en_fnamemm, endpointsmm)

    # only compute curvature if required
    if gconf.compute_curvature:
        meancurv = compute_curvature_array(fib)
        np.save(curv_fname, meancurv)

    log.info("========================")

    n = len(fib)

    resolution = gconf.parcellation.keys()

    for r in resolution:

        log.info("Resolution = "+r)

        # create empty fiber label array
        fiberlabels = np.zeros( (n, 2) )
        final_fiberlabels = []
        final_fibers_idx = []

        # Open the corresponding ROI
        log.info("Open the corresponding ROI")
        roi_fname = op.join(gconf.get_cmp_tracto_mask_tob0(), r, 'ROI_HR_th.nii.gz')
        roi       = nibabel.load(roi_fname)
        roiData   = roi.get_data()

        # Create the matrix
        nROIs = gconf.parcellation[r]['number_of_regions']
        log.info("Create the connection matrix (%s rois)" % nROIs)
        G     = nx.Graph()

        # add node information from parcellation
        gp = nx.read_graphml(gconf.parcellation[r]['node_information_graphml'])
        for u,d in gp.nodes_iter(data=True):
            G.add_node(int(u), d)

        # number of orphan fibers (an endpoint lies in an unlabeled voxel)
        dis = 0

        log.info("Create the connection matrix")
        for i in range(endpoints.shape[0]):

            # ROI start => ROI end
            try:
                startROI = int(roiData[endpoints[i, 0, 0], endpoints[i, 0, 1], endpoints[i, 0, 2]])
                endROI   = int(roiData[endpoints[i, 1, 0], endpoints[i, 1, 1], endpoints[i, 1, 2]])
            except IndexError:
                log.error("AN INDEXERROR EXCEPTION OCCURED FOR FIBER %s. PLEASE CHECK ENDPOINT GENERATION" % i)
                continue

            # Filter
            if startROI == 0 or endROI == 0:
                dis += 1
                fiberlabels[i,0] = -1
                continue

            if startROI > nROIs or endROI > nROIs:
                log.debug("Start or endpoint of fiber terminate in a voxel which is labeled higher")
                log.debug("than is expected by the parcellation node information.")
                log.debug("Start ROI: %i, End ROI: %i" % (startROI, endROI))
                log.debug("This needs bugfixing!")
                continue

            # Update fiber label
            # switch the rois in order to enforce startROI < endROI
            if endROI < startROI:
                startROI, endROI = endROI, startROI

            fiberlabels[i,0] = startROI
            fiberlabels[i,1] = endROI

            final_fiberlabels.append( [ startROI, endROI ] )
            final_fibers_idx.append(i)

            # Add edge to graph
            if G.has_edge(startROI, endROI):
                G.edge[startROI][endROI]['fiblist'].append(i)
            else:
                G.add_edge(startROI, endROI, fiblist   = [i])

        # Only fibers that passed every filter are valid; fibers skipped for
        # out-of-volume endpoints or too-high labels are not orphans but are
        # not valid either, so "n - dis" over-reported.  Guard the percentages
        # against an empty track file.
        n_valid = len(final_fibers_idx)
        pct_orphans = dis*100.0/n if n else 0.0
        pct_valid = n_valid*100.0/n if n else 0.0
        log.info("Found %i (%f percent out of %i fibers) fibers that start or terminate in a voxel which is not labeled. (orphans)" % (dis, pct_orphans, n) )
        log.info("Valid fibers: %i (%f percent)" % (n_valid, pct_valid) )

        # create a final fiber length array
        final_fiberlength_array = np.array( [length(fib[idx][0]) for idx in final_fibers_idx] )

        # make final fiber labels as array
        final_fiberlabels_array = np.array(final_fiberlabels, dtype = np.int32)

        # update edges
        # measures to add here
        # Iterate over a snapshot: the loop removes and re-adds every edge,
        # which must not happen while the live edge iterator is being consumed.
        for u,v,d in list(G.edges_iter(data=True)):
            G.remove_edge(u,v)
            di = { 'number_of_fibers' : len(d['fiblist']), }

            # additional measures
            # compute mean/std of fiber measure
            # Labels were stored as (smaller, larger); the graph may hand the
            # edge back in either orientation.
            lo, hi = min(int(u), int(v)), max(int(u), int(v))
            idx = np.where( (final_fiberlabels_array[:,0] == lo) & (final_fiberlabels_array[:,1] == hi) )[0]

            di['fiber_length_mean'] = np.mean(final_fiberlength_array[idx])
            di['fiber_length_std'] = np.std(final_fiberlength_array[idx])

            G.add_edge(u,v, di)

        # storing network
        nx.write_gpickle(G, op.join(gconf.get_cmp_matrices(), 'connectome_%s.gpickle' % r))

        log.info("Storing final fiber length array")
        fiberlabels_fname  = op.join(gconf.get_cmp_fibers(), 'final_fiberslength_%s.npy' % str(r))
        np.save(fiberlabels_fname, final_fiberlength_array)

        log.info("Storing all fiber labels (with orphans)")
        fiberlabels_fname  = op.join(gconf.get_cmp_fibers(), 'filtered_fiberslabel_%s.npy' % str(r))
        np.save(fiberlabels_fname, np.array(fiberlabels, dtype = np.int32), )

        log.info("Storing final fiber labels (no orphans)")
        fiberlabels_noorphans_fname  = op.join(gconf.get_cmp_fibers(), 'final_fiberlabels_%s.npy' % str(r))
        np.save(fiberlabels_noorphans_fname, final_fiberlabels_array)

        log.info("Filtering tractography - keeping only no orphan fibers")
        finalfibers_fname = op.join(gconf.get_cmp_fibers(), 'streamline_final_%s.trk' % str(r))
        save_fibers(hdr, fib, finalfibers_fname, final_fibers_idx)

    log.info("Done.")
    log.info("========================")