Beispiel #1
0
def test_spectral_connectivity():
    """Test frequency-domain connectivity methods.

    Signal 1 of every epoch is a band-pass filtered (5-15 Hz) copy of signal
    0 plus a little noise, so the pair (1, 0) should show strong coupling in
    that band. For every mode/method combination the test checks:

    * invalid arguments raise ``ValueError``;
    * thresholds on the simulated coupling (for 'coh', 'cohy', 'imcoh');
    * that an ``indices``/``n_jobs=2`` run on ``_stc_gen`` input reproduces
      the all-to-all result;
    * that band averaging (``faverage=True``) matches a manual average.
    """
    # Use a case known to have no spurious correlations (it would be bad if
    # nosetests could randomly fail):
    np.random.seed(0)

    sfreq = 50.
    n_signals = 3
    n_epochs = 10
    n_times = 500

    tmin = 0.
    tmax = (n_times - 1) / sfreq
    data = np.random.randn(n_epochs, n_signals, n_times)
    times_data = np.linspace(tmin, tmax, n_times)
    # simulate connectivity from 5Hz..15Hz
    fstart, fend = 5.0, 15.0
    for i in range(n_epochs):
        data[i, 1, :] = band_pass_filter(data[i, 0, :], sfreq, fstart, fend)
        # add some noise, so the spectrum is not exactly zero
        data[i, 1, :] += 1e-2 * np.random.randn(n_times)

    # First we test some invalid parameters:
    assert_raises(ValueError, spectral_connectivity, data, method='notamethod')
    assert_raises(ValueError, spectral_connectivity, data,
                  mode='notamode')

    # test invalid fmin fmax settings
    assert_raises(ValueError, spectral_connectivity, data, fmin=10,
                  fmax=10 + 0.5 * (sfreq / float(n_times)))
    assert_raises(ValueError, spectral_connectivity, data, fmin=10, fmax=5)
    assert_raises(ValueError, spectral_connectivity, data, fmin=(0, 11),
                  fmax=(5, 10))
    assert_raises(ValueError, spectral_connectivity, data, fmin=(11,),
                  fmax=(12, 15))

    # NOTE(review): 'coh' appears twice; presumably a leftover rather than
    # an intentional second pass -- confirm before removing it.
    methods = ['coh', 'imcoh', 'cohy', 'plv', 'ppc', 'pli', 'pli2_unbiased',
               'wpli', 'wpli2_debiased', 'coh']

    modes = ['multitaper', 'fourier', 'cwt_morlet']

    # define some frequencies for cwt
    cwt_frequencies = np.arange(3, 24.5, 1)

    for mode in modes:
        for method in methods:
            if method == 'coh' and mode == 'multitaper':
                # only check adaptive estimation for coh to reduce test time
                check_adaptive = [False, True]
            else:
                check_adaptive = [False]

            if method == 'coh' and mode == 'cwt_morlet':
                # so we also test using an array for num cycles
                cwt_n_cycles = 7. * np.ones(len(cwt_frequencies))
            else:
                cwt_n_cycles = 7.

            for adaptive in check_adaptive:

                if adaptive:
                    mt_bandwidth = 1.
                else:
                    mt_bandwidth = None

                con, freqs, times, n, _ = spectral_connectivity(
                    data, method=method, mode=mode, indices=None,
                    sfreq=sfreq, mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(n == n_epochs)
                assert_array_almost_equal(times_data, times)

                if mode == 'multitaper':
                    upper_t = 0.95
                    lower_t = 0.5
                else:
                    # other estimates have higher variance
                    upper_t = 0.8
                    lower_t = 0.75

                # test the simulated signal
                if method == 'coh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # we see something for zero-lag
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] > upper_t))

                    if mode != 'cwt_morlet':
                        idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                        assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                        assert_true(np.all(con[1, 0, idx[1]:] < lower_t))
                elif method == 'cohy':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(np.all(np.imag(con[1, 0, idx[0]:idx[1]])
                                < lower_t))
                    # we see something for zero-lag
                    assert_true(np.all(np.abs(con[1, 0, idx[0]:idx[1]])
                                > upper_t))

                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    if mode != 'cwt_morlet':
                        assert_true(np.all(np.abs(con[1, 0, :idx[0]])
                                    < lower_t))
                        assert_true(np.all(np.abs(con[1, 0, idx[1]:])
                                    < lower_t))
                elif method == 'imcoh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] < lower_t))
                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                    assert_true(np.all(con[1, 0, idx[1]:] < lower_t))

                # compute same connections using indices and 2 jobs,
                # also add a second method
                indices = tril_indices(n_signals, -1)

                test_methods = (method, _CohEst)
                stc_data = _stc_gen(data, sfreq, tmin)
                con2, freqs2, times2, n2, _ = spectral_connectivity(
                    stc_data, method=test_methods, mode=mode,
                    indices=indices, sfreq=sfreq, mt_adaptive=adaptive,
                    mt_low_bias=True, mt_bandwidth=mt_bandwidth, tmin=tmin,
                    tmax=tmax, cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles, n_jobs=2)

                assert_true(isinstance(con2, list))
                assert_true(len(con2) == 2)

                if method == 'coh':
                    # the string 'coh' and the _CohEst class must agree
                    assert_array_almost_equal(con2[0], con2[1])

                con2 = con2[0]  # only keep the first method

                # we get the same result for the probed connections
                assert_array_almost_equal(freqs, freqs2)
                assert_array_almost_equal(con[indices], con2)
                assert_true(n == n2)
                assert_array_almost_equal(times_data, times2)

                # compute same connections for two bands, fskip=1, and f. avg.
                fmin = (5., 15.)
                fmax = (15., 30.)
                con3, freqs3, times3, n3, _ = spectral_connectivity(
                    data, method=method, mode=mode, indices=indices,
                    sfreq=sfreq, fmin=fmin, fmax=fmax, fskip=1,
                    faverage=True, mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(isinstance(freqs3, list))
                assert_true(len(freqs3) == len(fmin))
                # every returned band frequency lies inside its [fmin, fmax]
                for band_freqs, f_lo, f_hi in zip(freqs3, fmin, fmax):
                    assert_true(np.all((band_freqs >= f_lo)
                                       & (band_freqs <= f_hi)))

                # average con2 "manually" and we get the same result
                for i, band_freqs in enumerate(freqs3):
                    freq_idx = np.searchsorted(freqs2, band_freqs)
                    con2_avg = np.mean(con2[:, freq_idx], axis=1)
                    assert_array_almost_equal(con2_avg, con3[:, i])
def test_spectral_connectivity():
    """Test frequency-domain connectivity methods.

    Signal 1 of every epoch is a band-pass filtered (5-15 Hz) copy of signal
    0 plus a little noise, so the pair (1, 0) should show strong coupling in
    that band. Besides single methods, one entry of ``methods`` is a list of
    several methods computed in a single call. Every result is compared
    against an ``indices``/``n_jobs=2`` run on ``_stc_gen`` input and against
    manual band averaging (``faverage=True``).
    """
    # XXX For some reason on 14 Oct 2015 Travis started timing out on this
    # test, so for a quick workaround we will skip it:
    if os.getenv('TRAVIS', 'false') == 'true':
        raise SkipTest('Travis is broken')
    # Use a case known to have no spurious correlations (it would be bad if
    # nosetests could randomly fail):
    np.random.seed(0)

    sfreq = 50.
    n_signals = 3
    n_epochs = 8
    n_times = 256

    tmin = 0.
    tmax = (n_times - 1) / sfreq
    data = np.random.randn(n_epochs, n_signals, n_times)
    times_data = np.linspace(tmin, tmax, n_times)
    # simulate connectivity from 5Hz..15Hz
    fstart, fend = 5.0, 15.0
    for i in range(n_epochs):
        # warnings emitted by the filter are recorded so they stay out of the
        # test output
        with warnings.catch_warnings(record=True):
            warnings.simplefilter('always')
            data[i, 1, :] = band_pass_filter(data[i, 0, :], sfreq, fstart,
                                             fend)
        # add some noise, so the spectrum is not exactly zero
        data[i, 1, :] += 1e-2 * np.random.randn(n_times)

    # First we test some invalid parameters:
    assert_raises(ValueError, spectral_connectivity, data, method='notamethod')
    assert_raises(ValueError, spectral_connectivity, data, mode='notamode')

    # test invalid fmin fmax settings
    assert_raises(ValueError,
                  spectral_connectivity,
                  data,
                  fmin=10,
                  fmax=10 + 0.5 * (sfreq / float(n_times)))
    assert_raises(ValueError, spectral_connectivity, data, fmin=10, fmax=5)
    assert_raises(ValueError,
                  spectral_connectivity,
                  data,
                  fmin=(0, 11),
                  fmax=(5, 10))
    assert_raises(ValueError,
                  spectral_connectivity,
                  data,
                  fmin=(11, ),
                  fmax=(12, 15))

    # the last entry is a list: several methods computed in one call
    methods = [
        'coh', 'cohy', 'imcoh',
        [
            'plv', 'ppc', 'pli', 'pli2_unbiased', 'wpli', 'wpli2_debiased',
            'coh'
        ]
    ]

    modes = ['multitaper', 'fourier', 'cwt_morlet']

    # define some frequencies for cwt
    cwt_frequencies = np.arange(3, 24.5, 1)

    for mode in modes:
        for method in methods:
            is_multi = isinstance(method, list)

            if method == 'coh' and mode == 'multitaper':
                # only check adaptive estimation for coh to reduce test time
                check_adaptive = [False, True]
            else:
                check_adaptive = [False]

            if method == 'coh' and mode == 'cwt_morlet':
                # so we also test using an array for num cycles
                cwt_n_cycles = 7. * np.ones(len(cwt_frequencies))
            else:
                cwt_n_cycles = 7.

            for adaptive in check_adaptive:

                if adaptive:
                    mt_bandwidth = 1.
                else:
                    mt_bandwidth = None

                con, freqs, times, n, _ = spectral_connectivity(
                    data,
                    method=method,
                    mode=mode,
                    indices=None,
                    sfreq=sfreq,
                    mt_adaptive=adaptive,
                    mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(n == n_epochs)
                assert_array_almost_equal(times_data, times)

                if mode == 'multitaper':
                    upper_t = 0.95
                    lower_t = 0.5
                else:
                    # other estimates have higher variance
                    upper_t = 0.8
                    lower_t = 0.75

                # test the simulated signal
                if method == 'coh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # we see something for zero-lag
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] > upper_t))

                    if mode != 'cwt_morlet':
                        idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                        assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                        assert_true(np.all(con[1, 0, idx[1]:] < lower_t))
                elif method == 'cohy':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(
                        np.all(np.imag(con[1, 0, idx[0]:idx[1]]) < lower_t))
                    # we see something for zero-lag
                    assert_true(
                        np.all(np.abs(con[1, 0, idx[0]:idx[1]]) > upper_t))

                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    if mode != 'cwt_morlet':
                        assert_true(
                            np.all(np.abs(con[1, 0, :idx[0]]) < lower_t))
                        assert_true(
                            np.all(np.abs(con[1, 0, idx[1]:]) < lower_t))
                elif method == 'imcoh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] < lower_t))
                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                    assert_true(np.all(con[1, 0, idx[1]:] < lower_t))

                # compute same connections using indices and 2 jobs
                indices = tril_indices(n_signals, -1)

                if not is_multi:
                    # also add a second method (the estimator class)
                    test_methods = (method, _CohEst)
                else:
                    test_methods = method

                stc_data = _stc_gen(data, sfreq, tmin)
                con2, freqs2, times2, n2, _ = spectral_connectivity(
                    stc_data,
                    method=test_methods,
                    mode=mode,
                    indices=indices,
                    sfreq=sfreq,
                    mt_adaptive=adaptive,
                    mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth,
                    tmin=tmin,
                    tmax=tmax,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles,
                    n_jobs=2)

                assert_true(isinstance(con2, list))
                assert_true(len(con2) == len(test_methods))

                if method == 'coh':
                    # the string 'coh' and the _CohEst class must agree
                    assert_array_almost_equal(con2[0], con2[1])

                if not is_multi:
                    con2 = con2[0]  # only keep the first method

                    # we get the same result for the probed connections
                    assert_array_almost_equal(freqs, freqs2)
                    assert_array_almost_equal(con[indices], con2)
                    assert_true(n == n2)
                    assert_array_almost_equal(times_data, times2)
                else:
                    # we get the same result for the probed connections;
                    # freqs/n/times are shared by all methods, check once
                    assert_true(len(con) == len(con2))
                    assert_array_almost_equal(freqs, freqs2)
                    assert_true(n == n2)
                    assert_array_almost_equal(times_data, times2)
                    for c, c2 in zip(con, con2):
                        assert_array_almost_equal(c[indices], c2)

                # compute same connections for two bands, fskip=1, and f. avg.
                fmin = (5., 15.)
                fmax = (15., 30.)
                con3, freqs3, times3, n3, _ = spectral_connectivity(
                    data,
                    method=method,
                    mode=mode,
                    indices=indices,
                    sfreq=sfreq,
                    fmin=fmin,
                    fmax=fmax,
                    fskip=1,
                    faverage=True,
                    mt_adaptive=adaptive,
                    mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(isinstance(freqs3, list))
                assert_true(len(freqs3) == len(fmin))
                # every returned band frequency lies inside its [fmin, fmax]
                for band_freqs, f_lo, f_hi in zip(freqs3, fmin, fmax):
                    assert_true(
                        np.all((band_freqs >= f_lo) & (band_freqs <= f_hi)))

                # average con2 "manually" and we get the same result
                if not is_multi:
                    for i, band_freqs in enumerate(freqs3):
                        freq_idx = np.searchsorted(freqs2, band_freqs)
                        con2_avg = np.mean(con2[:, freq_idx], axis=1)
                        assert_array_almost_equal(con2_avg, con3[:, i])
                else:
                    # guard against zip() silently truncating
                    assert_true(len(con3) == len(con2))
                    for c2, c3 in zip(con2, con3):
                        for i, band_freqs in enumerate(freqs3):
                            freq_idx = np.searchsorted(freqs2, band_freqs)
                            con2_avg = np.mean(c2[:, freq_idx], axis=1)
                            assert_array_almost_equal(con2_avg, c3[:, i])
Beispiel #3
0
def plot_connectivity_circle_cvu(con, nodes_numberless, indices=None,
                                 n_lines=10000, node_colors=None,
                                 colormap='YlOrRd', fig=None, reqrois=None,
                                 suppress_extra_rois=False,
                                 node_angles=None, node_width=None,
                                 facecolor='black', textcolor='white',
                                 node_edgecolor='black', linewidth=1.5,
                                 vmin=None, vmax=None, colorbar=False,
                                 title=None, fontsize_names='auto'):
    """Visualize connectivity as a circular graph.

    Note: This code is based on the circle graph example by Nicolas P. Rougier
    http://www.loria.fr/~rougier/coding/recipes.html

    This function replicates functionality from MNE python, by Martin Luessi
    and others. Many changes are made from the MNE python version.

    Parameters
    ----------
    con : array
        Connectivity scores. Can be a square matrix, or a 1D array. If a 1D
        array is provided, "indices" has to be used to define the connection
        indices. For a square matrix only the lower triangle is used.
    nodes_numberless : list of str
        Node names. The order corresponds to the order in con. Leading and
        trailing digits are stripped from the names before they are drawn.
    indices : tuple of arrays | None
        Two arrays with indices of connections for which the connection
        strengths are defined in con. Only needed if con is a 1D array.
    n_lines : int | None
        If not None, only the n_lines strongest connections (strength=abs(con))
        are drawn.
    node_colors : list of tuples | list of str | None
        List with the color to use for each node. If fewer colors than nodes
        are provided, the colors will be repeated. Any color supported by
        matplotlib can be used, e.g., RGBA tuples, named colors. If None, the
        nodes keep the default light gray ('.9').
    colormap : str | colormap
        Colormap to use for coloring the connections.
    fig : matplotlib Figure | None
        Figure to draw into (it is made current via its number). If None, a
        new 5x5 figure is created.
    reqrois : list of str | None
        Names of required ROI labels. Passed to the label-pruning helper
        (presumably these labels are protected from pruning -- confirm), and
        used by ``suppress_extra_rois``. None means an empty list.
    suppress_extra_rois : bool
        If True and ``reqrois`` is non-empty, only labels listed in
        ``reqrois`` are drawn.
    node_angles : array, shape=(len(nodes_numberless,)) | None
        Array with node positions in degrees. If None, the nodes are equally
        spaced on the circle. See mne.viz.circular_layout.
    node_width : float | None
        Width of each node in degrees. If None, "360. / len(nodes_numberless)"
        is used.
    facecolor : str
        Color to use for background. See matplotlib.colors.
    textcolor : str
        Color to use for text. See matplotlib.colors.
    node_edgecolor : str
        Color to use for lines around nodes. See matplotlib.colors.
    linewidth : float
        Line width to use for connections.
    vmin : float | None
        Minimum value for colormap. If None, it is determined automatically.
    vmax : float | None
        Maximum value for colormap. If None, it is determined automatically.
    colorbar : bool
        Display a colorbar or not.
    title : str
        The figure title.
    fontsize_names : int | str
        The fontsize for the node labels. If 'auto', the program attempts to
        determine a reasonable size. 'auto' is the default value.

    Returns
    -------
    fig : instance of pyplot.Figure
        The figure handle.
    """
    # avoid a shared mutable default argument
    if reqrois is None:
        reqrois = []

    n_nodes = len(nodes_numberless)

    if node_angles is not None:
        if len(node_angles) != n_nodes:
            raise ValueError('node_angles has to be the same length '
                             'as nodes_numberless')
        # convert it to radians (asarray so plain lists work too)
        node_angles = np.asarray(node_angles) * np.pi / 180
    else:
        # uniform layout on unit circle
        node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)

    if node_width is None:
        node_width = 2 * np.pi / n_nodes
    else:
        node_width = node_width * np.pi / 180

    # handle 1D and 2D connectivity information
    con = np.asarray(con)
    if con.ndim == 1:
        if indices is None:
            raise ValueError('indices has to be provided if con.ndim == 1')
        # we use the caller-supplied 1D indices
    elif con.ndim == 2:
        if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
            raise ValueError('con has to be 1D or a square matrix')
        # we use the lower-triangular part
        indices = tril_indices(n_nodes, -1)
        con = con[indices]
    else:
        raise ValueError('con has to be 1D or a square matrix')

    # get the colormap; ``basestring`` only exists on Python 2
    try:
        string_types = basestring
    except NameError:
        string_types = str
    if isinstance(colormap, string_types):
        colormap = pl.get_cmap(colormap)

    # Make figure background the same colors as axes
    if fig is None:
        fig = pl.figure(figsize=(5, 5), facecolor=facecolor)
    else:
        fig = pl.figure(num=fig.number)

    # Use a polar axes. matplotlib 2.0 replaced the ``axisbg`` keyword with
    # ``facecolor`` (``axisbg`` was later removed), so try the new name first
    # and fall back for old matplotlib versions.
    try:
        axes = pl.subplot(111, polar=True, facecolor=facecolor)
    except (AttributeError, TypeError):
        axes = pl.subplot(111, polar=True, axisbg=facecolor)

    # No ticks, we'll put our own
    pl.xticks([])
    pl.yticks([])

    # Set y axes limit
    pl.ylim(0, 10)

    # Draw lines between connected nodes, only draw the strongest connections
    if n_lines is not None and len(con) > n_lines:
        con_thresh = np.sort(np.abs(con).ravel())[-n_lines]
    else:
        con_thresh = 0.

    # keep only the connections which we are drawing; the input is expected
    # to be sorted by strength already (no re-sorting is done here)
    con_draw_idx = np.where(np.abs(con) >= con_thresh)[0]
    con = con[con_draw_idx]
    indices = [np.asarray(ind)[con_draw_idx] for ind in indices]

    # Get vmin vmax for color scaling
    if np.size(con) > 0:
        if vmin is None:
            # con only holds values with |con| >= con_thresh at this point
            vmin = np.min(con)
        if vmax is None:
            vmax = np.max(con)
        # float() keeps the later division true division for integer input
        vrange = float(vmax - vmin)
        if vrange == 0:
            # all values equal: avoid a division by zero when scaling
            vrange = 1.

    # We want to add some "noise" to the start and end position of the
    # edges: We modulate the noise with the number of connections of the
    # node and the connection strength, such that the strongest connections
    # are closer to the node center
    nodes_n_con = np.zeros((n_nodes), dtype=int)
    for i, j in zip(indices[0], indices[1]):
        nodes_n_con[i] += 1
        nodes_n_con[j] += 1

    # initialize random number generator so plot is reproducible
    rng = np.random.mtrand.RandomState(seed=0)

    n_con = len(indices[0])
    noise_max = 0.25 * node_width
    start_noise = rng.uniform(-noise_max, noise_max, n_con)
    end_noise = rng.uniform(-noise_max, noise_max, n_con)

    nodes_n_con_seen = np.zeros_like(nodes_n_con)
    for i, (start, end) in enumerate(zip(indices[0], indices[1])):
        nodes_n_con_seen[start] += 1
        nodes_n_con_seen[end] += 1

        # float() forces true division: with integer counts and Python 2
        # classic division this ratio would floor to 0 and drop the noise
        start_noise[i] *= ((nodes_n_con[start] - nodes_n_con_seen[start])
                           / float(nodes_n_con[start]))
        end_noise[i] *= ((nodes_n_con[end] - nodes_n_con_seen[end])
                         / float(nodes_n_con[end]))

    # scale connectivity for colormap (vmin<=>0, vmax<=>1)
    if np.size(con) > 0:
        con_val_scaled = (con - vmin) / vrange

    # Finally, we draw the connections
    for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
        # Start and end point on radius 7, shifted by the per-edge noise
        t0, r0 = node_angles[i] + start_noise[pos], 7
        t1, r1 = node_angles[j] + end_noise[pos], 7

        verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
        codes = [m_path.Path.MOVETO, m_path.Path.CURVE4, m_path.Path.CURVE4,
                 m_path.Path.LINETO]
        path = m_path.Path(verts, codes)

        color = colormap(con_val_scaled[pos])

        # Actual line
        patch = m_patches.PathPatch(path, fill=False, edgecolor=color,
                                    linewidth=linewidth, alpha=1.)
        axes.add_patch(patch)

    # Draw ring with colored nodes
    radii = np.ones(n_nodes) - .2
    bars = axes.bar(node_angles, radii, width=node_width, bottom=7.2,
                    edgecolor=node_edgecolor, linewidth=0, facecolor='.9',
                    align='center', zorder=10)

    if node_colors is not None and len(node_colors) > 0:
        # repeat the colors when fewer colors than nodes are given
        for k, bar in enumerate(bars):
            bar.set_facecolor(node_colors[k % len(node_colors)])

    # Draw node labels
    #
    # Basic idea: check for "too close" pairs (closer than pi/50), then
    # remove the smallest "too close" pairs. If multiple are tied within a
    # segment, remove pairs at equal spacing. Each segment is handled
    # individually by the pruning helpers.
    #
    # TODO this parameter, too_close, could be modified and adjusted for
    # a variety of sizes if ever the circle could be panned (or if it were
    # merely made bigger). Determining the proper value is a matter of
    # userspace testing.
    too_close = np.pi / 50

    # get angles for text placement (a dict: label name -> angle in radians)
    text_angles = get_labels_avg_idx(nodes_numberless, n_nodes, frac=1,
                                     pad=np.pi / 400)

    segments = get_tooclose_segments(text_angles, too_close, reqrois)

    for segment in segments:
        prune_segment(text_angles, segment, too_close)

    if suppress_extra_rois and len(reqrois) > 0:
        # iterate over a snapshot: deleting from a dict while iterating over
        # its keys view raises RuntimeError on Python 3
        for name in list(text_angles):
            if name not in reqrois:
                del text_angles[name]

        if fontsize_names == 'auto':
            fontsize_names = 10

    # TODO segments with many guaranteed ROIs are potentially spatially skewed
    # this is probably not worth fixing

    if fontsize_names == 'auto':
        fontsize_names = 8

    for name in text_angles:
        angle_rad = text_angles[name]
        angle_deg = 180 * angle_rad / np.pi
        if angle_deg >= 270 or angle_deg < 90:
            ha = 'left'
        else:
            # Flip the label, so text is always upright
            angle_deg += 180
            ha = 'right'

        # drop leading/trailing digits from the label text
        name_nonum = name.strip('1234567890')
        axes.text(angle_rad, 8.2, name_nonum, size=fontsize_names,
                  rotation=angle_deg,
                  rotation_mode='anchor', horizontalalignment=ha,
                  verticalalignment='center', color=textcolor)

    if title is not None:
        pl.subplots_adjust(left=0.2, bottom=0.2, right=0.8, top=0.75)
        pl.figtext(0.03, 0.95, title, color=textcolor, fontsize=14)
    else:
        pl.subplots_adjust(left=0.2, bottom=0.2, right=0.8, top=0.8)

    if colorbar:
        # ``Normalize`` is the class; the lowercase alias was removed
        sm = pl.cm.ScalarMappable(cmap=colormap,
                                  norm=pl.Normalize(vmin=vmin, vmax=vmax))
        sm.set_array(np.linspace(vmin, vmax))
        ax = fig.add_axes([.92, 0.03, .015, .25])
        cb = fig.colorbar(sm, cax=ax)
        cb_yticks = pl.getp(cb.ax.axes, 'yticklabels')
        pl.setp(cb_yticks, color=textcolor)

    return fig
Beispiel #4
0
def test_spectral_connectivity():
    """Test frequency-domain connectivity methods.

    For every mode/method combination the test checks:

    * invalid arguments raise ``ValueError``;
    * thresholds on the simulated coupling (for 'coh', 'cohy', 'imcoh');
    * that an ``indices``/``n_jobs=2`` run on ``_stc_gen`` input reproduces
      the all-to-all result;
    * that band averaging (``faverage=True``) matches a manual average.

    NOTE(review): unlike the other variants, this body does not build its
    own test data. ``data``, ``sfreq``, ``n_times``, ``n_epochs``,
    ``n_signals``, ``tmin``, ``tmax``, ``times_data``, ``fstart`` and
    ``fend`` must come from module scope (not visible here) -- confirm,
    otherwise this test raises NameError.
    """

    # First we test some invalid parameters:
    assert_raises(ValueError, spectral_connectivity, data, method='notamethod')
    assert_raises(ValueError, spectral_connectivity, data, mode='notamode')

    # test invalid fmin fmax settings
    assert_raises(ValueError,
                  spectral_connectivity,
                  data,
                  fmin=10,
                  fmax=10 + 0.5 * (sfreq / float(n_times)))
    assert_raises(ValueError, spectral_connectivity, data, fmin=10, fmax=5)
    assert_raises(ValueError,
                  spectral_connectivity,
                  data,
                  fmin=(0, 11),
                  fmax=(5, 10))
    assert_raises(ValueError,
                  spectral_connectivity,
                  data,
                  fmin=(11, ),
                  fmax=(12, 15))

    # NOTE(review): 'coh' appears twice; presumably a leftover rather than
    # an intentional second pass -- confirm before removing it.
    methods = [
        'coh', 'imcoh', 'cohy', 'plv', 'ppc', 'pli', 'pli2_unbiased', 'wpli',
        'wpli2_debiased', 'coh'
    ]

    modes = ['multitaper', 'fourier', 'cwt_morlet']

    # define some frequencies for cwt
    cwt_frequencies = np.arange(3, 24.5, 1)

    for mode in modes:
        for method in methods:
            if method == 'coh' and mode == 'multitaper':
                # only check adaptive estimation for coh to reduce test time
                check_adaptive = [False, True]
            else:
                check_adaptive = [False]

            if method == 'coh' and mode == 'cwt_morlet':
                # so we also test using an array for num cycles
                cwt_n_cycles = 7. * np.ones(len(cwt_frequencies))
            else:
                cwt_n_cycles = 7.

            for adaptive in check_adaptive:

                if adaptive:
                    mt_bandwidth = 1.
                else:
                    mt_bandwidth = None

                con, freqs, times, n, _ = spectral_connectivity(
                    data,
                    method=method,
                    mode=mode,
                    indices=None,
                    sfreq=sfreq,
                    mt_adaptive=adaptive,
                    mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(n == n_epochs)
                assert_array_almost_equal(times_data, times)

                if mode == 'multitaper':
                    upper_t = 0.95
                    lower_t = 0.5
                else:
                    # other estimates have higher variance
                    upper_t = 0.8
                    lower_t = 0.75

                # test the simulated signal
                if method == 'coh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # we see something for zero-lag
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] > upper_t))

                    if mode != 'cwt_morlet':
                        idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                        assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                        assert_true(np.all(con[1, 0, idx[1]:] < lower_t))
                elif method == 'cohy':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(
                        np.all(np.imag(con[1, 0, idx[0]:idx[1]]) < lower_t))
                    # we see something for zero-lag
                    assert_true(
                        np.all(np.abs(con[1, 0, idx[0]:idx[1]]) > upper_t))

                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    if mode != 'cwt_morlet':
                        assert_true(
                            np.all(np.abs(con[1, 0, :idx[0]]) < lower_t))
                        assert_true(
                            np.all(np.abs(con[1, 0, idx[1]:]) < lower_t))
                elif method == 'imcoh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] < lower_t))
                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                    assert_true(np.all(con[1, 0, idx[1]:] < lower_t))

                # compute same connections using indices and 2 jobs,
                # also add a second method
                indices = tril_indices(n_signals, -1)

                test_methods = (method, _CohEst)
                stc_data = _stc_gen(data, sfreq, tmin)
                con2, freqs2, times2, n2, _ = spectral_connectivity(
                    stc_data,
                    method=test_methods,
                    mode=mode,
                    indices=indices,
                    sfreq=sfreq,
                    mt_adaptive=adaptive,
                    mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth,
                    tmin=tmin,
                    tmax=tmax,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles,
                    n_jobs=2)

                assert_true(isinstance(con2, list))
                assert_true(len(con2) == 2)

                if method == 'coh':
                    # the string 'coh' and the _CohEst class must agree
                    assert_array_almost_equal(con2[0], con2[1])

                con2 = con2[0]  # only keep the first method

                # we get the same result for the probed connections
                assert_array_almost_equal(freqs, freqs2)
                assert_array_almost_equal(con[indices], con2)
                assert_true(n == n2)
                assert_array_almost_equal(times_data, times2)

                # compute same connections for two bands, fskip=1, and f. avg.
                fmin = (5., 15.)
                fmax = (15., 30.)
                con3, freqs3, times3, n3, _ = spectral_connectivity(
                    data,
                    method=method,
                    mode=mode,
                    indices=indices,
                    sfreq=sfreq,
                    fmin=fmin,
                    fmax=fmax,
                    fskip=1,
                    faverage=True,
                    mt_adaptive=adaptive,
                    mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(isinstance(freqs3, list))
                assert_true(len(freqs3) == len(fmin))
                # every returned band frequency lies inside its [fmin, fmax]
                for band_freqs, f_lo, f_hi in zip(freqs3, fmin, fmax):
                    assert_true(
                        np.all((band_freqs >= f_lo) & (band_freqs <= f_hi)))

                # average con2 "manually" and we get the same result
                for i, band_freqs in enumerate(freqs3):
                    freq_idx = np.searchsorted(freqs2, band_freqs)
                    con2_avg = np.mean(con2[:, freq_idx], axis=1)
                    assert_array_almost_equal(con2_avg, con3[:, i])
Beispiel #5
0
def _cvu_hemisphere_order(nodes_numberless, bilateral_symmetry,
                          start_hemi='l'):
    """Compute the circular display order of the nodes.

    The first character of each node name is used as its hemisphere tag, and
    the nodes are assumed to be grouped contiguously by hemisphere. The
    "pivot" is the first position whose tag differs from the first node's
    tag, or the number of nodes if all tags agree.

    Parameters
    ----------
    nodes_numberless : sequence of str
        Node names in their original order.
    bilateral_symmetry : bool
        If True, the block of nodes that ends up second is reversed so that
        the circle is mirror-symmetric. Otherwise both blocks keep their
        internal order.
    start_hemi : str
        Hemisphere tag whose block should come first.

    Returns
    -------
    order : list of int
        Permutation such that ``order[new_position] == original_position``.
    """
    n_nodes = len(nodes_numberless)
    first_hemi = nodes_numberless[0][0]
    hemi_pivot = next((pos for pos, name in enumerate(nodes_numberless)
                       if name[0] != first_hemi), n_nodes)
    leading = list(range(hemi_pivot))
    trailing = list(range(hemi_pivot, n_nodes))

    if first_hemi == start_hemi:
        # the requested hemisphere already comes first
        if bilateral_symmetry:
            return leading + trailing[::-1]
        return leading + trailing
    # the other hemisphere comes first: move the requested one to the front
    if bilateral_symmetry:
        return trailing + leading[::-1]
    return trailing + leading


def plot_connectivity_circle_cvu(con,
                                 nodes_numberless,
                                 indices=None,
                                 n_lines=10000,
                                 node_colors=None,
                                 colormap='YlOrRd',
                                 fig=None,
                                 reqrois=None,
                                 suppress_extra_rois=False,
                                 node_angles=None,
                                 node_width=None,
                                 facecolor='black',
                                 textcolor='white',
                                 node_edgecolor='black',
                                 linewidth=1.5,
                                 vmin=None,
                                 vmax=None,
                                 colorbar=False,
                                 title=None,
                                 fontsize_names='auto',
                                 bilateral_symmetry=True):
    """Visualize connectivity as a circular graph.

    Note: This code is originally taken from public open-source
    examples in matplotlib by Nicolas P. Rougier. It was adapted for use in
    MNE python, primarily by Martin Luessi, but also by all the other
    contributors to MNE python.

    There are some differences between the current version and the MNE
    python version. Most importantly, the current version offers less
    flexibility of the layout of the plot and has algorithms to determine
    this layout automatically given the ordering of the CVU dataset. Each
    hemisphere takes up roughly half the space and the left hemisphere is
    always placed first on the circle. A heuristic
    (``get_tooclose_segments`` / ``prune_segment``) then suppresses extra
    label names so that the remaining label names are readable. Note that
    the suppression of label names can be overwritten in the GUI, although
    it is quite effortful; typically it is recommended to do image
    postprocessing instead.

    Parameters
    ----------
    con : array
        Connectivity scores. Can be a square (n_nodes, n_nodes) matrix, of
        which only the strict lower triangle is used, or a 1D array. If a 1D
        array is provided, "indices" has to be used to define the connection
        indices.
    nodes_numberless : list of str
        Node names. The order corresponds to the order in con. The first
        character of each name is used as the hemisphere tag; names whose
        tag is 'l' are placed first. Nodes are assumed to be grouped by
        hemisphere.
    indices : tuple of arrays | array, shape (2, n_con) | None
        Two arrays with indices of connections for which the connection
        strengths are defined in con. Only needed if con is a 1D array.
    n_lines : int | None
        If not None, only the n_lines strongest connections (strength =
        abs(con)) are drawn.
    node_colors : list of tuples | list of str | None
        List with the color to use for each node. If fewer colors than nodes
        are provided, the colors will be repeated. Any color supported by
        matplotlib can be used, e.g., RGBA tuples, named colors. If None,
        the nodes are drawn in light gray ('.9').
    colormap : str | CustomColormap | matplotlib colormap
        Colormap to use for coloring the connections.
    fig : matplotlib Figure | None
        Figure to draw into. If None, a new 5x5 inch figure is created.
    reqrois : list of str | None
        Label names forwarded to ``get_tooclose_segments`` (presumably
        labels that must survive pruning; confirm there). If
        ``suppress_extra_rois`` is True, only these labels are shown.
        None is treated as an empty list.
    suppress_extra_rois : bool
        If True and ``reqrois`` is non-empty, hide every label not in
        ``reqrois``. An 'auto' font size then becomes 10 instead of 8.
    node_angles : array, shape (len(nodes_numberless),) | None
        Array with node positions in degrees. If None, the nodes are equally
        spaced on the circle. See mne.viz.circular_layout.
    node_width : float | None
        Width of each node in degrees. If None, "360. / len(nodes_numberless)"
        is used.
    facecolor : str
        Color to use for background. See matplotlib.colors.
    textcolor : str
        Color to use for text. See matplotlib.colors.
    node_edgecolor : str
        Color to use for lines around nodes. See matplotlib.colors.
    linewidth : float
        Line width to use for connections.
    vmin : float | None
        Minimum value for colormap. If None, it is determined automatically.
    vmax : float | None
        Maximum value for colormap. If None, it is determined automatically.
    colorbar : bool
        Display a colorbar or not.
    title : str
        The figure title.
    fontsize_names : int | str
        The fontsize for the node labels. If 'auto', the program attempts to
        determine a reasonable size. 'auto' is the default value.
    bilateral_symmetry : bool
        If True, the second hemisphere is drawn in reverse order so that the
        circle is mirror-symmetric. Otherwise, only the left hemisphere is
        moved to the front.

    Returns
    -------
    fig : instance of pyplot.Figure
        The figure handle.
    node_angles : array
        Node angles in radians, as used for drawing (offset by pi / 2).

    Raises
    ------
    ValueError
        If con is neither 1D nor a square (n_nodes, n_nodes) matrix, if con
        is 1D and indices is None, or if node_angles has the wrong length.
    """
    if reqrois is None:
        reqrois = []
    n_nodes = len(nodes_numberless)

    # Bring the connectivity into 1D form with an explicit (2, n_con) index
    # array *before* reordering the nodes, so that the indices can be
    # remapped together with the node names.
    if con.ndim == 1:
        if indices is None:
            raise ValueError('indices has to be provided if con.ndim == 1')
        indices = np.asarray(indices)
    elif con.ndim == 2:
        if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
            raise ValueError('con has to be 1D or a square matrix')
        # we use the lower-triangular part
        indices = np.asarray(tril_indices(n_nodes, -1))
        con = con[tuple(indices)]
    else:
        raise ValueError('con has to be 1D or a square matrix')

    # Put the left hemisphere first and, optionally, mirror the other one.
    order = _cvu_hemisphere_order(nodes_numberless, bilateral_symmetry,
                                  start_hemi='l')
    nodes_numberless = [nodes_numberless[k] for k in order]
    if node_colors is not None and len(node_colors) > 0:
        # repeat the colors if fewer colors than nodes were given
        n_colors = len(node_colors)
        node_colors = [node_colors[k % n_colors] for k in order]
    if indices.size > 0:
        # new_pos[original_index] is the node's position after reordering
        new_pos = np.empty(n_nodes, dtype=int)
        new_pos[order] = np.arange(n_nodes)
        indices = new_pos[indices]

    if node_angles is not None:
        if len(node_angles) != n_nodes:
            raise ValueError('node_angles has to be the same length '
                             'as nodes_numberless')
        # convert it to radians
        node_angles = np.asarray(node_angles, dtype=float) * np.pi / 180
    else:
        # uniform layout on unit circle
        node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)
    # offset every node by 90 degrees
    node_angles += np.pi / 2

    if node_width is None:
        node_width = 2 * np.pi / n_nodes
    else:
        node_width = node_width * np.pi / 180

    # get the colormap
    if isinstance(colormap, CustomColormap):
        colormap = colormap._get__pl()
    elif isinstance(colormap, basestring):
        colormap = pl.get_cmap(colormap)

    # Make figure background the same colors as axes
    if fig is None:
        fig = pl.figure(figsize=(5, 5), facecolor=facecolor)
    else:
        # make the given figure current so that pyplot draws into it
        fig = pl.figure(num=fig.number)

    # Use a polar axes.
    # NOTE(review): ``axisbg`` (and ``pl.normalize`` below) may not exist in
    # newer matplotlib releases; confirm the targeted matplotlib version.
    axes = pl.subplot(111, polar=True, axisbg=facecolor)

    # No ticks, we'll put our own
    pl.xticks([])
    pl.yticks([])

    # Set y axes limit
    pl.ylim(0, 10)

    axes.spines['polar'].set_visible(False)

    # Draw lines between connected nodes, only draw the strongest connections
    con_abs = np.abs(con)
    if n_lines is not None and len(con) > n_lines:
        con_thresh = np.sort(con_abs.ravel())[-n_lines]
    else:
        con_thresh = 0.

    # keep only the connections we are drawing; they are drawn in input
    # order (the input is assumed to be already sorted by strength)
    con_draw_idx = np.where(con_abs >= con_thresh)[0]
    con = con[con_draw_idx]
    indices = [ind[con_draw_idx] for ind in indices]

    # Get vmin vmax for color scaling
    if np.size(con) > 0:
        if vmin is None:
            # con was filtered above, so every entry passes the threshold
            vmin = np.min(con)
        if vmax is None:
            vmax = np.max(con)
        vrange = vmax - vmin
        if vrange == 0:
            # all values equal: avoid 0 / 0, whose NaN result would be drawn
            # with the colormap's "bad" (possibly invisible) color
            vrange = 1.

    # We want to add some "noise" to the start and end position of the
    # edges: We modulate the noise with the number of connections of the
    # node and the connection strength, such that the strongest connections
    # are closer to the node center
    nodes_n_con = np.zeros(n_nodes, dtype=int)
    for i, j in zip(indices[0], indices[1]):
        nodes_n_con[i] += 1
        nodes_n_con[j] += 1

    # initialize random number generator so plot is reproducible
    rng = np.random.mtrand.RandomState(seed=0)

    n_con = len(indices[0])
    noise_max = 0.25 * node_width
    start_noise = rng.uniform(-noise_max, noise_max, n_con)
    end_noise = rng.uniform(-noise_max, noise_max, n_con)

    nodes_n_con_seen = np.zeros_like(nodes_n_con)
    for i, (start, end) in enumerate(zip(indices[0], indices[1])):
        nodes_n_con_seen[start] += 1
        nodes_n_con_seen[end] += 1

        # float() prevents integer floor division (Python 2), which would
        # zero the noise of every connection
        start_noise[i] *= ((nodes_n_con[start] - nodes_n_con_seen[start]) /
                           float(nodes_n_con[start]))
        end_noise[i] *= ((nodes_n_con[end] - nodes_n_con_seen[end]) /
                         float(nodes_n_con[end]))

    # scale connectivity for colormap (vmin<=>0, vmax<=>1)
    if np.size(con) > 0:
        con_val_scaled = (con - vmin) / float(vrange)

    # Finally, we draw the connections
    for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
        # Start and end point (with some noise on the angle). They sit at
        # radius 7, just inside the node ring that starts at radius 7.2
        t0, r0 = node_angles[i] + start_noise[pos], 7
        t1, r1 = node_angles[j] + end_noise[pos], 7

        verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
        codes = [
            m_path.Path.MOVETO, m_path.Path.CURVE4, m_path.Path.CURVE4,
            m_path.Path.LINETO
        ]
        path = m_path.Path(verts, codes)

        color = colormap(con_val_scaled[pos])

        # Actual line
        patch = m_patches.PathPatch(path,
                                    fill=False,
                                    edgecolor=color,
                                    linewidth=linewidth,
                                    alpha=1.)
        axes.add_patch(patch)

    # Draw ring with colored nodes
    radii = np.ones(n_nodes) - .2
    bars = axes.bar(node_angles,
                    radii,
                    width=node_width,
                    bottom=7.2,
                    edgecolor=node_edgecolor,
                    linewidth=0,
                    facecolor='.9',
                    align='center',
                    zorder=10)

    if node_colors is not None:
        for bar, color in zip(bars, node_colors):
            bar.set_facecolor(color)

    # Draw node labels.
    # Basic idea: find segments of labels that are "too close" together
    # (closer than too_close radians) and prune them so that the remaining
    # labels are readable.
    # TODO this parameter, too_close, could be modified and adjusted for
    # a variety of sizes if ever the circle could be panned (or if it were
    # merely made bigger). Determining the proper value is a matter of
    # userspace testing.
    too_close = np.pi / 50

    # get angles for text placement (presumably a dict keyed by label name)
    text_angles = get_labels_avg_idx(nodes_numberless,
                                     n_nodes,
                                     frac=1,
                                     pad=np.pi / 400)

    segments = get_tooclose_segments(text_angles, too_close, reqrois)
    for segment in segments:
        prune_segment(text_angles, segment, too_close)

    if suppress_extra_rois and len(reqrois) > 0:
        # iterate over a snapshot of the keys: deleting from a dict while
        # iterating over its live key view raises RuntimeError on Python 3
        for name in list(text_angles.keys()):
            if name not in reqrois:
                del text_angles[name]

        if fontsize_names == 'auto':
            fontsize_names = 10

    # TODO segments with many guaranteed ROIs are potentially spatially
    # skewed; this is probably not worth fixing

    if fontsize_names == 'auto':
        fontsize_names = 8

    for name in text_angles:
        angle_rad = text_angles[name] + np.pi / 2
        angle_deg = 180 * angle_rad / np.pi
        if angle_deg >= 270 or angle_deg < 90:
            ha = 'left'
        else:
            # Flip the label, so text is always upright
            angle_deg += 180
            ha = 'right'

        # strip digits from both ends of the label name
        name_nonum = name.strip('1234567890')
        axes.text(angle_rad,
                  8.2,
                  name_nonum,
                  size=fontsize_names,
                  rotation=angle_deg,
                  rotation_mode='anchor',
                  horizontalalignment=ha,
                  verticalalignment='center',
                  color=textcolor)

    if title is not None:
        pl.subplots_adjust(left=0.2, bottom=0.2, right=0.8, top=0.75)
        pl.figtext(0.03, 0.95, title, color=textcolor, fontsize=14)
    else:
        pl.subplots_adjust(left=0.2, bottom=0.2, right=0.8, top=0.8)

    if colorbar:
        sm = pl.cm.ScalarMappable(cmap=colormap,
                                  norm=pl.normalize(vmin=vmin, vmax=vmax))
        sm.set_array(np.linspace(vmin, vmax))
        ax = fig.add_axes([.92, 0.03, .015, .25])
        cb = fig.colorbar(sm, cax=ax)
        cb_yticks = pl.getp(cb.ax.axes, 'yticklabels')
        pl.setp(cb_yticks, color=textcolor)

    return fig, node_angles