Beispiel #1
0
def _make_label_connectivity():
    """Build a small LabelConnectivity fixture.

    Three left-hemisphere labels covering vertices 0-2, 3-5 and 6-8 are
    connected by the pairs (0, 1), (0, 2) and (1, 2), carrying the data
    values 0, 1 and 2 respectively.
    """
    names = ('Label1', 'Label2', 'Label3')
    labels = [
        Label(vertices=np.arange(3 * idx, 3 * (idx + 1)), hemi='lh',
              name=name)
        for idx, name in enumerate(names)
    ]

    # pairs[0] holds the first index of each pair, pairs[1] the second
    sources = [0, 0, 1]
    targets = [1, 2, 2]
    pairs = [sources, targets]
    data = np.arange(len(targets))

    return LabelConnectivity(data, pairs, labels)
Beispiel #2
0
def _generate_labels(vertices, n_labels):
    """Split the vertices of each hemisphere into consecutive chunk labels.

    Parameters
    ----------
    vertices : sequence of two array-like
        ``(vert_lh, vert_rh)``: the vertices of the left and of the right
        hemisphere.
    n_labels : int
        Target number of labels. ``n_labels // 2`` are assigned to the left
        hemisphere and the remaining ones to the right hemisphere.

    Returns
    -------
    labels : list of Label
        Left-hemisphere labels followed by right-hemisphere labels. Each
        label is named ``'Label<start>'`` where ``<start>`` is the offset of
        its first vertex within the hemisphere's vertex array. When the
        number of vertices of a hemisphere is not a multiple of its label
        count, the remainder forms one additional, shorter label.

    Raises
    ------
    ZeroDivisionError
        If ``n_labels < 2`` (no label would be assigned to the left
        hemisphere).
    ValueError
        If a hemisphere has fewer vertices than labels assigned to it (the
        chunk size, used as ``range`` step, would be zero).
    """
    vert_lh, vert_rh = vertices
    n_lh_labels = n_labels // 2
    n_rh_labels = n_labels - n_lh_labels

    labels = []
    for verts, hemi, n_hemi_labels in ((vert_lh, 'lh', n_lh_labels),
                                       (vert_rh, 'rh', n_rh_labels)):
        # Each hemisphere is stepped with its OWN chunk size; stepping the
        # right hemisphere with the left chunk size produced overlapping
        # or gapped labels whenever the two sizes differed.
        n_chunk = len(verts) // n_hemi_labels
        labels.extend(
            Label(vertices=verts[start:start + n_chunk], hemi=hemi,
                  name='Label' + str(start))
            for start in range(0, len(verts), n_chunk))

    return labels
Beispiel #3
0
def plot_roi(
    hemi,
    labels,
    color,
    annotation="HCPMMP1",
    view="parietal",
    fs_dir=None,
    subject_id="S04",
    surf="inflated",
):
    """Render a set of ROIs on a PySurfer brain and return a screenshot.

    ROIs are collected from two places in ``<fs_dir>/<subject_id>/label``:
    individual ``<hemi>*.label`` files and the regions of the
    ``<hemi>.<annotation>.annot`` parcellation. A file or region is shown
    when any entry of ``labels`` is a substring of the file path or of the
    region name, respectively.

    Parameters
    ----------
    hemi : str
        Hemisphere passed to ``surfer.Brain``. For ``"lh"`` / ``"rh"`` the
        label and annotation files of that hemisphere are used; any other
        value falls back to the left-hemisphere files.
    labels : list of str
        ROI names. A ``"-lh"`` / ``"-rh"`` part is stripped before matching.
    color : matplotlib color
        Color used for every ROI (converted with ``to_rgba``).
    annotation : str
        Name of the annotation (parcellation) to read.
    view : str
        View handed to ``Brain.show_view``.
    fs_dir : str | None
        Directory containing the subject folders. If None, the
        ``SUBJECTS_DIR`` environment variable is read when the function is
        called.
    subject_id : str
        Name of the subject folder.
    surf : str
        Surface passed to ``surfer.Brain``.

    Returns
    -------
    screenshot
        The value returned by ``Brain.screenshot()``.

    Raises
    ------
    KeyError
        If ``fs_dir`` is None and ``SUBJECTS_DIR`` is not set.
    """
    import glob
    import os

    from matplotlib import colors as mcolors
    from mne import Label
    from nibabel.freesurfer import io
    from surfer import Brain

    # Resolve the default lazily so that importing this module does not
    # require SUBJECTS_DIR and later changes of the variable are honoured.
    if fs_dir is None:
        fs_dir = os.environ["SUBJECTS_DIR"]

    # Label and annotation files exist per hemisphere. Previously "lh" was
    # hard-coded, so hemi="rh" tried to add left-hemisphere labels to a
    # right-hemisphere brain. Other values keep the former "lh" lookup.
    file_hemi = hemi if hemi in ("lh", "rh") else "lh"

    color = np.array(mcolors.to_rgba(color))

    # NOTE(review): Brain is not given fs_dir, so it presumably resolves the
    # subject through its own SUBJECTS_DIR lookup - confirm this matches
    # fs_dir when a non-default directory is passed.
    brain = Brain(subject_id, hemi, surf, offscreen=False)
    roi_names = [
        name.replace("-rh", "").replace("-lh", "") for name in labels
    ]
    label_dir = os.path.join(fs_dir, subject_id, "label")

    # First the stand-alone label files
    for label_path in glob.glob(
            os.path.join(label_dir, file_hemi + "*.label")):
        if any(name in label_path for name in roi_names):
            brain.add_label(label_path, color=color)

    # Now go for the regions of the annotation
    ids, ctab, annot_names = io.read_annot(
        os.path.join(label_dir, "%s.%s.annot" % (file_hemi, annotation)),
        orig_ids=True,
    )

    for row, raw_name in enumerate(annot_names):
        region = raw_name.decode("utf-8")
        if any(name in region for name in roi_names):
            # With orig_ids=True the per-vertex ids are matched against the
            # last column of the color table row of the region.
            vertices = np.where(ids == ctab[row, -1])[0]
            roi = Label(np.sort(vertices), hemi=file_hemi)
            brain.add_label(roi, color=color)

    brain.show_view(view)
    return brain.screenshot()
Beispiel #4
0
def extract_roi(stc, src, label=None, thresh=0.5):
    """Extract a functional ROI around the peak of a source estimate.

    The maximum of the data (optionally restricted to ``label``) is located,
    the data are clustered at ``thresh`` times that maximum, and the cluster
    containing the peak is turned into a label.

    Parameters
    ----------
    stc : instance of SourceEstimate
        The source estimate data. The maximum positive peak will be selected.
        If you want the maximum negative peak, consider passing
        abs(stc) or -stc.
    src : instance of SourceSpaces
        The associated source space.
    label : instance of Label | None
        The label within which to select the peak.
        Can be None to use the entire STC.
    thresh : float
        Threshold value (relative to the peak value) above which vertices
        will be taken.

    Returns
    -------
    func_label : instance of Label
        The functional ROI, i.e. the label built from the peak's cluster
        after calling its ``fill(src)`` method.
    max_vert : int
        Vertex number of the peak, taken from the vertices of the
        hemisphere that contains it.
    max_vidx : int
        Row index of the peak in ``stc.data``.
    max_tidx : int
        Column index of the peak in ``stc.data``.

    Raises
    ------
    AssertionError
        If ``stc`` is not a SourceEstimate, or if the peak value found in
        the restricted data does not match ``stc.data`` at the mapped index.
    RuntimeError
        If no cluster contains the peak.
    """
    assert isinstance(stc, SourceEstimate)
    # Restrict the peak search to ``label`` when one is given
    if label is None:
        stc_label = stc.copy()
    else:
        stc_label = stc.in_label(label)
    del label
    # (row, column) position of the maximum in the restricted data
    max_vidx, max_tidx = np.unravel_index(np.argmax(stc_label.data),
                                          stc_label.data.shape)
    max_val = stc_label.data[max_vidx, max_tidx]
    # Translate the row index of the restricted data into a row index of the
    # full ``stc.data``: rows below len(vertices[0]) are treated as left
    # hemisphere, the remaining ones as right hemisphere.
    if max_vidx < len(stc_label.vertices[0]):
        hemi = 'lh'
        max_vert = stc_label.vertices[0][max_vidx]
        max_vidx = list(stc.vertices[0]).index(max_vert)
    else:
        hemi = 'rh'
        max_vert = stc_label.vertices[1][max_vidx - len(stc_label.vertices[0])]
        max_vidx = list(stc.vertices[1]).index(max_vert)
        max_vidx += len(stc.vertices[0])
    del stc_label
    # Sanity check: the mapped index must point at the same value
    assert max_val == stc.data[max_vidx, max_tidx]

    # Get contiguous vertices above ``thresh`` times the peak value
    threshold = max_val * thresh
    # NOTE(review): the adjacency is passed through the ``connectivity``
    # keyword below - confirm the installed MNE version still accepts that
    # keyword name.
    connectivity = spatial_src_adjacency(src, verbose='error')  # holes
    # A single "observation" and one permutation: the test is only used for
    # its clustering; stat_fun just returns the data themselves.
    _, clusters, _, _ = spatio_temporal_cluster_1samp_test(
        np.array([stc.data]), threshold, n_permutations=1,
        stat_fun=lambda x: x.mean(0), tail=1,
        connectivity=connectivity)
    # NOTE(review): this assumes cluster[0] holds row (vertex) indices and
    # cluster[1] holds column (time) indices - verify against the layout the
    # clustering function expects for its input and returns for clusters.
    for cluster in clusters:
        if max_vidx in cluster[0] and max_tidx in cluster[1]:
            break  # found our cluster
    else:  # in case we did not "break"
        raise RuntimeError('Clustering failed somehow!')
    # NOTE(review): ``cluster`` is indexed as a pair above but used as a
    # whole for indexing / arithmetic here - confirm this selects the
    # intended vertex indices.
    if hemi == 'lh':
        verts = stc.vertices[0][cluster]
    else:
        verts = stc.vertices[1][cluster - len(stc.vertices[0])]
    func_label = Label(verts, hemi=hemi, subject=stc.subject)
    func_label = func_label.fill(src)
    return func_label, max_vert, max_vidx, max_tidx
Beispiel #5
0
def get_surface_labels(surface, texture, subject, hemi):
    """Get one label per parcel (area) of a surface.

    For every distinct texture value, the nodes carrying that value are
    reduced to those belonging to at least one triangle whose three corners
    all lie in the parcel; these nodes form the label of the parcel.

    Parameters
    ----------
    surface : instance of Surface
        The surface; its ``'rr'`` (node positions) and ``'tris'``
        (triangles) entries are used.
    texture : str | array
        Array with one parcel value per node, or the filename of a gifti
        texture to read it from (first data array of the file).
    subject : str
        Name of the subject
    hemi : 'lh' | 'rh'
        Name of the hemisphere

    Returns
    -------
    labels : list of Label
        MarsAtlas labels for surface sources, one per distinct texture
        value, in ascending order of that value. Parcels missing from the
        MarsAtlas parcel table are named ``'no_name'``.
    """
    rr = surface['rr']
    triangles = surface['tris']

    # Get texture with gifti format (BrainVisa) = labels of MarsAtlas
    if isinstance(texture, str):
        values = gifti.giftiio.read(texture).darrays[0].data
    else:
        values = texture

    # Parcel value -> parcel information; the first item is the name
    from bv2mne.marsatlas_parcels import ma_parcels

    labels = []
    for val in np.unique(values):
        name_process = ma_parcels.get(val, False)
        name = name_process[0] if name_process else 'no_name'

        # Nodes of the parcel
        # assumes ``values`` is one-dimensional (one entry per node)
        parcel_nodes = np.where(values == val)[0]

        # Per triangle, flag the corners that belong to the parcel and keep
        # the triangles lying entirely inside it (all 3 corners flagged)
        corner_in_parcel = np.isin(triangles, parcel_nodes)
        inner_tris = triangles[corner_in_parcel.sum(1) == 3]

        # Keep only the parcel nodes connected through such a triangle
        connected = np.isin(parcel_nodes, np.unique(inner_tris))
        vertex_ind = parcel_nodes[connected]

        label = Label(vertices=vertex_ind,
                      pos=rr[vertex_ind],
                      values=values[vertex_ind],
                      hemi=hemi,
                      comment=name,
                      name=name,
                      filename=None,
                      subject=subject,
                      verbose=None)

        labels.append(label)

    return labels
def test_extract_label_time_course(kind, vector):
    """Test extraction of label time courses from (Mixed)SourceEstimate.

    Source estimates with a known constant value per label (the label's
    index) and per volume source space (-1, -2, ...) are generated, and the
    extracted time courses are compared against these values for the
    different extraction modes.

    Parameters
    ----------
    kind : str
        ``'mixed'`` adds two volume source spaces to the surface source
        space; any other value uses the surface source space only.
        Presumably supplied by a pytest parametrization outside this block.
    vector : bool
        Whether to use the vector source estimate class instead of its
        scalar counterpart. Presumably supplied by a pytest parametrization
        outside this block.
    """
    n_stcs = 3
    n_times = 50

    src = read_inverse_operator(fname_inv)['src']
    if kind == 'mixed':
        label_names = ('Left-Cerebellum-Cortex', 'Right-Cerebellum-Cortex')
        src += setup_volume_source_space('sample',
                                         pos=20.,
                                         volume_label=label_names,
                                         subjects_dir=subjects_dir,
                                         add_interpolator=False)
        klass = MixedVectorSourceEstimate
    else:
        klass = VectorSourceEstimate
    if not vector:
        klass = klass._scalar_class
    vertices = [s['vertno'] for s in src]
    n_verts = np.array([len(v) for v in vertices])
    # One constant (-1, -2, ...) per source space beyond the two surface
    # hemispheres; empty when there are no volume source spaces
    vol_means = np.arange(-1, 1 - len(src), -1)
    vol_means_t = np.repeat(vol_means[:, np.newaxis], n_times, axis=1)

    # get some labels
    labels_lh = read_labels_from_annot('sample',
                                       hemi='lh',
                                       subjects_dir=subjects_dir)
    labels_rh = read_labels_from_annot('sample',
                                       hemi='rh',
                                       subjects_dir=subjects_dir)
    labels = list()
    labels.extend(labels_lh[:5])
    labels.extend(labels_rh[:4])

    n_labels = len(labels)

    # Expected time courses: label i is constant at value i
    label_tcs = dict(mean=np.arange(n_labels)[:, None] *
                     np.ones((n_labels, n_times)))
    label_tcs['max'] = label_tcs['mean']

    # compute the mean with sign flip
    label_tcs['mean_flip'] = np.zeros_like(label_tcs['mean'])
    for i, label in enumerate(labels):
        label_tcs['mean_flip'][i] = i * np.mean(label_sign_flip(
            label, src[:2]))

    # generate some stc's with known data
    stcs = list()
    # np.pad arguments that insert two zero entries in front along axis 1
    pad = (((0, 0), (2, 0), (0, 0)), 'constant')
    for i in range(n_stcs):
        data = np.zeros((n_verts.sum(), n_times))
        # set the value of the stc within each label
        for j, label in enumerate(labels):
            if label.hemi == 'lh':
                idx = np.intersect1d(vertices[0], label.vertices)
                idx = np.searchsorted(vertices[0], idx)
            elif label.hemi == 'rh':
                idx = np.intersect1d(vertices[1], label.vertices)
                idx = len(vertices[0]) + np.searchsorted(vertices[1], idx)
            data[idx] = label_tcs['mean'][j]
        for j in range(len(vol_means)):
            # Volume source space j is source space 2 + j: it starts after
            # all preceding source spaces and spans n_verts[2 + j] rows.
            # (Using n_verts[j], a hemisphere's size, only worked because
            # later writes and slice clipping hid the overrun.)
            offset = n_verts[:2 + j].sum()
            data[offset:offset + n_verts[2 + j]] = vol_means[j]

        if vector:
            # put the values on the Z axis
            data = np.pad(data[:, np.newaxis], *pad)
        this_stc = klass(data, vertices, 0, 1)
        stcs.append(this_stc)

    if vector:
        # Give the expected values the same zero-padded layout as the data
        for key in label_tcs:
            label_tcs[key] = np.pad(label_tcs[key][:, np.newaxis], *pad)
        vol_means_t = np.pad(vol_means_t[:, np.newaxis], *pad)

    # test some invalid inputs
    with pytest.raises(ValueError, match="Invalid value for the 'mode'"):
        extract_label_time_course(stcs, labels, src, mode='notamode')

    # have an empty label
    empty_label = labels[0].copy()
    empty_label.vertices += 1000000
    with pytest.raises(ValueError, match='does not contain any vertices'):
        extract_label_time_course(stcs, empty_label, src)

    # but this works:
    with pytest.warns(RuntimeWarning, match='does not contain any vertices'):
        tc = extract_label_time_course(stcs,
                                       empty_label,
                                       src,
                                       allow_empty=True)
    end_shape = (3, n_times) if vector else (n_times, )
    for arr in tc:
        assert arr.shape == (1 + len(vol_means), ) + end_shape
        assert_array_equal(arr[:1], np.zeros((1, ) + end_shape))
        if len(vol_means):
            assert_array_equal(arr[1:], vol_means_t)

    # test the different modes
    modes = ['mean', 'mean_flip', 'pca_flip', 'max', 'auto']

    for mode in modes:
        if vector and mode not in ('mean', 'max', 'auto'):
            with pytest.raises(ValueError, match='when using a vector'):
                extract_label_time_course(stcs, labels, src, mode=mode)
            continue
        # The function and the method form must agree
        label_tc = extract_label_time_course(stcs, labels, src, mode=mode)
        label_tc_method = [
            stc.extract_label_time_course(labels, src, mode=mode)
            for stc in stcs
        ]
        assert (len(label_tc) == n_stcs)
        assert (len(label_tc_method) == n_stcs)
        for tc1, tc2 in zip(label_tc, label_tc_method):
            assert tc1.shape == (n_labels + len(vol_means), ) + end_shape
            assert tc2.shape == (n_labels + len(vol_means), ) + end_shape
            assert_allclose(tc1, tc2, rtol=1e-8, atol=1e-16)
            if mode == 'auto':
                use_mode = 'mean' if vector else 'mean_flip'
            else:
                use_mode = mode
            # XXX we don't check pca_flip, probably should someday...
            if use_mode in ('mean', 'max', 'mean_flip'):
                assert_array_almost_equal(tc1[:n_labels], label_tcs[use_mode])
            assert_array_almost_equal(tc1[n_labels:], vol_means_t)

    # test label with very few vertices (check SVD conditionals)
    label = Label(vertices=src[0]['vertno'][:2], hemi='lh')
    x = label_sign_flip(label, src[:2])
    assert (len(x) == 2)
    label = Label(vertices=[], hemi='lh')
    x = label_sign_flip(label, src[:2])
    assert (x.size == 0)