def test_spectral_connectivity():
    """Test frequency-domain connectivity methods.

    Signal 1 is a band-passed (5-15 Hz) copy of signal 0 plus a little
    noise. For every mode/method combination the test checks:

    - invalid parameters raise ``ValueError``;
    - connectivity is high inside the coupled band and low outside it;
    - the ``indices`` / ``n_jobs=2`` / ``_stc_gen`` input path reproduces the
      full computation;
    - frequency averaging over two bands matches a manual average.
    """
    # Use a case known to have no spurious correlations (it would be bad if
    # nosetests could randomly fail):
    np.random.seed(0)

    sfreq = 50.
    n_signals = 3
    n_epochs = 10
    n_times = 500

    tmin = 0.
    tmax = (n_times - 1) / sfreq
    data = np.random.randn(n_epochs, n_signals, n_times)
    times_data = np.linspace(tmin, tmax, n_times)
    # simulate connectivity from 5Hz..15Hz
    fstart, fend = 5.0, 15.0
    for i in range(n_epochs):
        data[i, 1, :] = band_pass_filter(data[i, 0, :],
                                         sfreq, fstart, fend)
        # add some noise, so the spectrum is not exactly zero
        data[i, 1, :] += 1e-2 * np.random.randn(n_times)

    # First we test some invalid parameters:
    assert_raises(ValueError, spectral_connectivity, data, method='notamethod')
    assert_raises(ValueError, spectral_connectivity, data, mode='notamode')

    # test invalid fmin fmax settings
    assert_raises(ValueError, spectral_connectivity, data, fmin=10,
                  fmax=10 + 0.5 * (sfreq / float(n_times)))
    assert_raises(ValueError, spectral_connectivity, data, fmin=10, fmax=5)
    assert_raises(ValueError, spectral_connectivity, data, fmin=(0, 11),
                  fmax=(5, 10))
    assert_raises(ValueError, spectral_connectivity, data, fmin=(11,),
                  fmax=(12, 15))

    methods = ['coh', 'imcoh', 'cohy', 'plv', 'ppc', 'pli', 'pli2_unbiased',
               'wpli', 'wpli2_debiased', 'coh']

    modes = ['multitaper', 'fourier', 'cwt_morlet']

    # define some frequencies for cwt
    cwt_frequencies = np.arange(3, 24.5, 1)

    for mode in modes:
        for method in methods:
            if method == 'coh' and mode == 'multitaper':
                # only check adaptive estimation for coh to reduce test time
                check_adaptive = [False, True]
            else:
                check_adaptive = [False]

            if method == 'coh' and mode == 'cwt_morlet':
                # so we also test using an array for num cycles
                cwt_n_cycles = 7. * np.ones(len(cwt_frequencies))
            else:
                cwt_n_cycles = 7.

            for adaptive in check_adaptive:

                if adaptive:
                    mt_bandwidth = 1.
                else:
                    mt_bandwidth = None

                con, freqs, times, n, _ = spectral_connectivity(
                    data, method=method, mode=mode, indices=None, sfreq=sfreq,
                    mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth, cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(n == n_epochs)
                assert_array_almost_equal(times_data, times)

                if mode == 'multitaper':
                    upper_t = 0.95
                    lower_t = 0.5
                else:
                    # other estimates have higher variance
                    upper_t = 0.8
                    lower_t = 0.75

                # test the simulated signal
                if method == 'coh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # we see something for zero-lag
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] > upper_t))

                    if mode != 'cwt_morlet':
                        idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                        assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                        assert_true(np.all(con[1, 0, idx[1]:] < lower_t))
                elif method == 'cohy':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(np.all(np.imag(con[1, 0, idx[0]:idx[1]]) <
                                       lower_t))
                    # we see something for zero-lag
                    assert_true(np.all(np.abs(con[1, 0, idx[0]:idx[1]]) >
                                       upper_t))

                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    if mode != 'cwt_morlet':
                        assert_true(np.all(np.abs(con[1, 0, :idx[0]]) <
                                           lower_t))
                        assert_true(np.all(np.abs(con[1, 0, idx[1]:]) <
                                           lower_t))
                elif method == 'imcoh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] < lower_t))
                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                    assert_true(np.all(con[1, 0, idx[1]:] < lower_t))

                # compute same connections using indices and 2 jobs,
                # also add a second method
                indices = tril_indices(n_signals, -1)

                test_methods = (method, _CohEst)
                stc_data = _stc_gen(data, sfreq, tmin)
                con2, freqs2, times2, n2, _ = spectral_connectivity(
                    stc_data, method=test_methods, mode=mode, indices=indices,
                    sfreq=sfreq, mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth, tmin=tmin, tmax=tmax,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles, n_jobs=2)

                # one result per requested method
                assert_true(isinstance(con2, list))
                assert_true(len(con2) == 2)

                if method == 'coh':
                    # the _CohEst estimator is expected to match built-in coh
                    assert_array_almost_equal(con2[0], con2[1])

                con2 = con2[0]  # only keep the first method

                # we get the same result for the probed connections
                assert_array_almost_equal(freqs, freqs2)
                assert_array_almost_equal(con[indices], con2)
                assert_true(n == n2)
                assert_array_almost_equal(times_data, times2)

                # compute same connections for two bands, fskip=1, and f. avg.
                fmin = (5., 15.)
                fmax = (15., 30.)
                con3, freqs3, times3, n3, _ = spectral_connectivity(
                    data, method=method, mode=mode, indices=indices,
                    sfreq=sfreq, fmin=fmin, fmax=fmax, fskip=1, faverage=True,
                    mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth, cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(isinstance(freqs3, list))
                assert_true(len(freqs3) == len(fmin))
                for i in range(len(freqs3)):
                    assert_true(np.all((freqs3[i] >= fmin[i]) &
                                       (freqs3[i] <= fmax[i])))

                # average con2 "manually" and we get the same result
                for i in range(len(freqs3)):
                    freq_idx = np.searchsorted(freqs2, freqs3[i])
                    con2_avg = np.mean(con2[:, freq_idx], axis=1)
                    assert_array_almost_equal(con2_avg, con3[:, i])
def test_spectral_connectivity():
    """Test frequency-domain connectivity methods.

    Signal 1 is a band-passed (5-15 Hz) copy of signal 0 plus a little
    noise. Each entry of ``methods`` is either a single method name or a
    list of names computed in one call. For each entry and mode the test
    checks:

    - invalid parameters raise ``ValueError``;
    - connectivity is high inside the coupled band and low outside it;
    - the ``indices`` / ``n_jobs=2`` / ``_stc_gen`` input path reproduces the
      full computation;
    - frequency averaging over two bands matches a manual average.
    """
    # XXX For some reason on 14 Oct 2015 Travis started timing out on this
    # test, so for a quick workaround we will skip it:
    if os.getenv('TRAVIS', 'false') == 'true':
        raise SkipTest('Travis is broken')

    # Use a case known to have no spurious correlations (it would be bad if
    # nosetests could randomly fail):
    np.random.seed(0)

    sfreq = 50.
    n_signals = 3
    n_epochs = 8
    n_times = 256

    tmin = 0.
    tmax = (n_times - 1) / sfreq
    data = np.random.randn(n_epochs, n_signals, n_times)
    times_data = np.linspace(tmin, tmax, n_times)
    # simulate connectivity from 5Hz..15Hz
    fstart, fend = 5.0, 15.0
    for i in range(n_epochs):
        # silence any warnings emitted by the filter
        with warnings.catch_warnings(record=True):
            warnings.simplefilter('always')
            data[i, 1, :] = band_pass_filter(data[i, 0, :],
                                             sfreq, fstart, fend)
        # add some noise, so the spectrum is not exactly zero
        data[i, 1, :] += 1e-2 * np.random.randn(n_times)

    # First we test some invalid parameters:
    assert_raises(ValueError, spectral_connectivity, data, method='notamethod')
    assert_raises(ValueError, spectral_connectivity, data, mode='notamode')

    # test invalid fmin fmax settings
    assert_raises(ValueError, spectral_connectivity, data, fmin=10,
                  fmax=10 + 0.5 * (sfreq / float(n_times)))
    assert_raises(ValueError, spectral_connectivity, data, fmin=10, fmax=5)
    assert_raises(ValueError, spectral_connectivity, data, fmin=(0, 11),
                  fmax=(5, 10))
    assert_raises(ValueError, spectral_connectivity, data, fmin=(11, ),
                  fmax=(12, 15))

    # the last entry is a list: all of those methods are computed at once
    methods = [
        'coh', 'cohy', 'imcoh', [
            'plv', 'ppc', 'pli', 'pli2_unbiased', 'wpli', 'wpli2_debiased',
            'coh'
        ]
    ]

    modes = ['multitaper', 'fourier', 'cwt_morlet']

    # define some frequencies for cwt
    cwt_frequencies = np.arange(3, 24.5, 1)

    for mode in modes:
        for method in methods:
            if method == 'coh' and mode == 'multitaper':
                # only check adaptive estimation for coh to reduce test time
                check_adaptive = [False, True]
            else:
                check_adaptive = [False]

            if method == 'coh' and mode == 'cwt_morlet':
                # so we also test using an array for num cycles
                cwt_n_cycles = 7. * np.ones(len(cwt_frequencies))
            else:
                cwt_n_cycles = 7.

            for adaptive in check_adaptive:

                if adaptive:
                    mt_bandwidth = 1.
                else:
                    mt_bandwidth = None

                con, freqs, times, n, _ = spectral_connectivity(
                    data, method=method, mode=mode, indices=None, sfreq=sfreq,
                    mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth, cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(n == n_epochs)
                assert_array_almost_equal(times_data, times)

                if mode == 'multitaper':
                    upper_t = 0.95
                    lower_t = 0.5
                else:
                    # other estimates have higher variance
                    upper_t = 0.8
                    lower_t = 0.75

                # test the simulated signal
                if method == 'coh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # we see something for zero-lag
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] > upper_t))

                    if mode != 'cwt_morlet':
                        idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                        assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                        assert_true(np.all(con[1, 0, idx[1]:] < lower_t))
                elif method == 'cohy':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(
                        np.all(np.imag(con[1, 0, idx[0]:idx[1]]) < lower_t))
                    # we see something for zero-lag
                    assert_true(
                        np.all(np.abs(con[1, 0, idx[0]:idx[1]]) > upper_t))

                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    if mode != 'cwt_morlet':
                        assert_true(
                            np.all(np.abs(con[1, 0, :idx[0]]) < lower_t))
                        assert_true(
                            np.all(np.abs(con[1, 0, idx[1]:]) < lower_t))
                elif method == 'imcoh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] < lower_t))
                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                    assert_true(np.all(con[1, 0, idx[1]:] < lower_t))

                # compute same connections using indices and 2 jobs
                indices = tril_indices(n_signals, -1)

                if not isinstance(method, list):
                    # also add a second method
                    test_methods = (method, _CohEst)
                else:
                    test_methods = method

                stc_data = _stc_gen(data, sfreq, tmin)
                con2, freqs2, times2, n2, _ = spectral_connectivity(
                    stc_data, method=test_methods, mode=mode, indices=indices,
                    sfreq=sfreq, mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth, tmin=tmin, tmax=tmax,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles, n_jobs=2)

                # one result per requested method
                assert_true(isinstance(con2, list))
                assert_true(len(con2) == len(test_methods))

                if method == 'coh':
                    # the _CohEst estimator is expected to match built-in coh
                    assert_array_almost_equal(con2[0], con2[1])

                if not isinstance(method, list):
                    con2 = con2[0]  # only keep the first method

                    # we get the same result for the probed connections
                    assert_array_almost_equal(freqs, freqs2)
                    assert_array_almost_equal(con[indices], con2)
                    assert_true(n == n2)
                    assert_array_almost_equal(times_data, times2)
                else:
                    # we get the same result for the probed connections;
                    # frequencies, epoch count and times are shared by all
                    # methods, so they are checked once outside the loop
                    assert_true(len(con) == len(con2))
                    assert_array_almost_equal(freqs, freqs2)
                    assert_true(n == n2)
                    assert_array_almost_equal(times_data, times2)
                    for c, c2 in zip(con, con2):
                        assert_array_almost_equal(c[indices], c2)

                # compute same connections for two bands, fskip=1, and f. avg.
                fmin = (5., 15.)
                fmax = (15., 30.)
                con3, freqs3, times3, n3, _ = spectral_connectivity(
                    data, method=method, mode=mode, indices=indices,
                    sfreq=sfreq, fmin=fmin, fmax=fmax, fskip=1, faverage=True,
                    mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth, cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(isinstance(freqs3, list))
                assert_true(len(freqs3) == len(fmin))
                for i in range(len(freqs3)):
                    assert_true(
                        np.all((freqs3[i] >= fmin[i]) &
                               (freqs3[i] <= fmax[i])))

                # average con2 "manually" and we get the same result
                if not isinstance(method, list):
                    for i in range(len(freqs3)):
                        freq_idx = np.searchsorted(freqs2, freqs3[i])
                        con2_avg = np.mean(con2[:, freq_idx], axis=1)
                        assert_array_almost_equal(con2_avg, con3[:, i])
                else:
                    for j in range(len(con2)):
                        for i in range(len(freqs3)):
                            freq_idx = np.searchsorted(freqs2, freqs3[i])
                            con2_avg = np.mean(con2[j][:, freq_idx], axis=1)
                            assert_array_almost_equal(con2_avg, con3[j][:, i])
def plot_connectivity_circle_cvu(con, nodes_numberless, indices=None,
                                 n_lines=10000, node_colors=None,
                                 colormap='YlOrRd', fig=None, reqrois=None,
                                 suppress_extra_rois=False, node_angles=None,
                                 node_width=None, facecolor='black',
                                 textcolor='white', node_edgecolor='black',
                                 linewidth=1.5, vmin=None, vmax=None,
                                 colorbar=False, title=None,
                                 fontsize_names='auto'):
    """Visualize connectivity as a circular graph.

    Note: This code is based on the circle graph example by Nicolas P. Rougier
    http://www.loria.fr/~rougier/coding/recipes.html

    This function replicates functionality from MNE python, by Martin Luessi
    and others. Many changes are made from the MNE python version.

    Parameters
    ----------
    con : array
        Connectivity scores. Can be a square matrix, or a 1D array. If a 1D
        array is provided, "indices" has to be used to define the connection
        indices.
    nodes_numberless : list of str
        Node names. The order corresponds to the order in con. Digits are
        stripped from both ends of each name before it is drawn.
    indices : tuple of arrays | None
        Two arrays with indices of connections for which the connection
        strengths are defined in con. Only needed if con is a 1D array.
    n_lines : int | None
        If not None, only the n_lines strongest connections (strength=abs(con))
        are drawn.
    node_colors : list of tuples | list of str | None
        List with the color to use for each node. If fewer colors than nodes
        are provided, the colors will be repeated. Any color supported by
        matplotlib can be used, e.g., RGBA tuples, named colors. If None, all
        nodes keep the default light-gray face color.
    colormap : str | matplotlib colormap
        Colormap to use for coloring the connections.
    fig : instance of pyplot.Figure | None
        Figure to draw into. If None, a new 5x5 figure is created.
    reqrois : list of str | None
        Node names handed to the label-pruning helper (presumably labels that
        must be kept). With suppress_extra_rois, only these labels are drawn.
        None means an empty list.
    suppress_extra_rois : bool
        If True and reqrois is non-empty, draw only the labels in reqrois.
    node_angles : array, shape=(len(nodes_numberless,)) | None
        Array with node positions in degrees. If None, the nodes are equally
        spaced on the circle. See mne.viz.circular_layout.
    node_width : float | None
        Width of each node in degrees. If None, "360. / len(nodes_numberless)"
        is used.
    facecolor : str
        Color to use for background. See matplotlib.colors.
    textcolor : str
        Color to use for text. See matplotlib.colors.
    node_edgecolor : str
        Color to use for lines around nodes. See matplotlib.colors.
    linewidth : float
        Line width to use for connections.
    vmin : float | None
        Minimum value for colormap. If None, it is determined automatically.
    vmax : float | None
        Maximum value for colormap. If None, it is determined automatically.
    colorbar : bool
        Display a colorbar or not.
    title : str
        The figure title.
    fontsize_names : int | str
        The fontsize for the node labels. If 'auto', the program attempts to
        determine a reasonable size. 'auto' is the default value.

    Returns
    -------
    fig : instance of pyplot.Figure
        The figure handle.
    """
    from itertools import cycle

    # avoid a shared mutable default argument
    if reqrois is None:
        reqrois = []

    n_nodes = len(nodes_numberless)

    if node_angles is not None:
        if len(node_angles) != n_nodes:
            raise ValueError('node_angles has to be the same length '
                             'as nodes_numberless')
        # convert it to radians
        node_angles = node_angles * np.pi / 180
    else:
        # uniform layout on unit circle
        node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)

    if node_width is None:
        node_width = 2 * np.pi / n_nodes
    else:
        node_width = node_width * np.pi / 180

    # handle 1D and 2D connectivity information
    if con.ndim == 1:
        if indices is None:
            raise ValueError('indices has to be provided if con.ndim == 1')
        # we use 1D indices
    elif con.ndim == 2:
        if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
            raise ValueError('con has to be 1D or a square matrix')
        # we use the lower-triangular part
        indices = tril_indices(n_nodes, -1)
        con = con[indices]
    else:
        raise ValueError('con has to be 1D or a square matrix')

    # get the colormap
    if isinstance(colormap, basestring):
        colormap = pl.get_cmap(colormap)

    # Make figure background the same colors as axes
    if fig is None:
        fig = pl.figure(figsize=(5, 5), facecolor=facecolor)
    else:
        fig = pl.figure(num=fig.number)

    # Use a polar axes
    # NOTE(review): 'axisbg' was deprecated in matplotlib 2.0 in favour of
    # 'facecolor' and later removed; kept for the matplotlib this file targets
    axes = pl.subplot(111, polar=True, axisbg=facecolor)

    # No ticks, we'll put our own
    pl.xticks([])
    pl.yticks([])

    # Set y axes limit
    pl.ylim(0, 10)

    # Draw lines between connected nodes, only draw the strongest connections
    if n_lines is not None and len(con) > n_lines:
        con_thresh = np.sort(np.abs(con).ravel())[-n_lines]
    else:
        con_thresh = 0.

    # keep only the connections we are drawing; no sorting happens here,
    # connections are drawn in the order they were given
    con_abs = np.abs(con)
    con_draw_idx = np.where(con_abs >= con_thresh)[0]

    con = con[con_draw_idx]
    con_abs = con_abs[con_draw_idx]
    indices = [ind[con_draw_idx] for ind in indices]

    # Get vmin vmax for color scaling
    if np.size(con) > 0:
        if vmin is None:
            vmin = np.min(con[np.abs(con) >= con_thresh])
        if vmax is None:
            vmax = np.max(con)
        vrange = vmax - vmin

    # We want to add some "noise" to the start and end position of the
    # edges: We modulate the noise with the number of connections of the
    # node and the connection strength, such that the strongest connections
    # are closer to the node center
    nodes_n_con = np.zeros((n_nodes), dtype=int)
    for i, j in zip(indices[0], indices[1]):
        nodes_n_con[i] += 1
        nodes_n_con[j] += 1

    # initialize random number generator so plot is reproducible
    rng = np.random.RandomState(seed=0)

    n_con = len(indices[0])
    noise_max = 0.25 * node_width
    start_noise = rng.uniform(-noise_max, noise_max, n_con)
    end_noise = rng.uniform(-noise_max, noise_max, n_con)

    nodes_n_con_seen = np.zeros_like(nodes_n_con)
    for i, (start, end) in enumerate(zip(indices[0], indices[1])):
        nodes_n_con_seen[start] += 1
        nodes_n_con_seen[end] += 1

        # float() avoids integer (floor) division under Python 2, which would
        # otherwise zero out all of the noise
        start_noise[i] *= ((nodes_n_con[start] - nodes_n_con_seen[start]) /
                           float(nodes_n_con[start]))
        end_noise[i] *= ((nodes_n_con[end] - nodes_n_con_seen[end]) /
                         float(nodes_n_con[end]))

    # scale connectivity for colormap (vmin<=>0, vmax<=>1)
    if np.size(con) > 0:
        con_val_scaled = (con - vmin) / vrange

    # Finally, we draw the connections
    for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
        # Start point
        t0, r0 = node_angles[i], 7

        # End point
        t1, r1 = node_angles[j], 7

        # Some noise in start and end point
        t0 += start_noise[pos]
        t1 += end_noise[pos]

        verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
        codes = [m_path.Path.MOVETO, m_path.Path.CURVE4, m_path.Path.CURVE4,
                 m_path.Path.LINETO]
        path = m_path.Path(verts, codes)

        color = colormap(con_val_scaled[pos])

        # Actual line
        patch = m_patches.PathPatch(path, fill=False, edgecolor=color,
                                    linewidth=linewidth, alpha=1.)
        axes.add_patch(patch)

    # Draw ring with colored nodes
    radii = np.ones(n_nodes) - .2
    bars = axes.bar(node_angles, radii, width=node_width, bottom=7.2,
                    edgecolor=node_edgecolor, linewidth=0, facecolor='.9',
                    align='center', zorder=10)

    # repeat the colors if fewer than nodes were given (as documented);
    # with node_colors=None the default face color is kept
    if node_colors is not None:
        for bar, color in zip(bars, cycle(node_colors)):
            bar.set_facecolor(color)

    # Draw node labels
    # basic idea -- check for "too close" pairs. too close is pi/50
    # remove smallest "too close" pairs. if multiple tied within a segment,
    # remove pairs at equal spacing.
    # calculate each segment individually and find the extent of the segment.

    # TODO this parameter, too_close, could be modified and adjusted for
    # a variety of sizes if ever the circle could be panned (or if it were
    # merely made bigger). determining the proper value is a matter of
    # userspace testing
    too_close = np.pi / 50

    # get angles for text placement
    text_angles = get_labels_avg_idx(nodes_numberless, n_nodes, frac=1,
                                     pad=np.pi / 400)

    segments = get_tooclose_segments(text_angles, too_close, reqrois)

    for segment in segments:
        prune_segment(text_angles, segment, too_close)

    if suppress_extra_rois and len(reqrois) > 0:
        # iterate over a snapshot: deleting while iterating a live key view
        # raises RuntimeError on Python 3
        for name in list(text_angles.keys()):
            if name not in reqrois:
                del text_angles[name]

        if fontsize_names == 'auto':
            fontsize_names = 10

    # TODO segments with many guaranteed ROIs are potentially spatially
    # skewed; this is probably not worth fixing

    if fontsize_names == 'auto':
        fontsize_names = 8

    for name in text_angles:
        angle_rad = text_angles[name]
        angle_deg = 180 * angle_rad / np.pi
        if angle_deg >= 270 or angle_deg < 90:
            ha = 'left'
        else:
            # Flip the label, so text is always upright
            angle_deg += 180
            ha = 'right'

        name_nonum = name.strip('1234567890')
        axes.text(angle_rad, 8.2, name_nonum, size=fontsize_names,
                  rotation=angle_deg, rotation_mode='anchor',
                  horizontalalignment=ha, verticalalignment='center',
                  color=textcolor)

    if title is not None:
        pl.subplots_adjust(left=0.2, bottom=0.2, right=0.8, top=0.75)
        pl.figtext(0.03, 0.95, title, color=textcolor, fontsize=14)
    else:
        pl.subplots_adjust(left=0.2, bottom=0.2, right=0.8, top=0.8)

    if colorbar:
        # Normalize (capitalised) exists in all matplotlib versions; the
        # lowercase alias 'normalize' was removed
        sm = pl.cm.ScalarMappable(cmap=colormap,
                                  norm=pl.Normalize(vmin=vmin, vmax=vmax))
        sm.set_array(np.linspace(vmin, vmax))
        ax = fig.add_axes([.92, 0.03, .015, .25])
        cb = fig.colorbar(sm, cax=ax)
        cb_yticks = pl.getp(cb.ax.axes, 'yticklabels')
        pl.setp(cb_yticks, color=textcolor)

    return fig
def test_spectral_connectivity():
    """Test frequency-domain connectivity methods.

    For every mode/method combination the test checks:

    - invalid parameters raise ``ValueError``;
    - connectivity is high inside the simulated band and low outside it;
    - the ``indices`` / ``n_jobs=2`` / ``_stc_gen`` input path reproduces the
      full computation;
    - frequency averaging over two bands matches a manual average.
    """
    # NOTE(review): data, sfreq, n_times, n_epochs, n_signals, tmin, tmax,
    # times_data, fstart and fend are not defined in this function; they are
    # presumably module-level fixtures defined elsewhere -- confirm.

    # First we test some invalid parameters:
    assert_raises(ValueError, spectral_connectivity, data, method='notamethod')
    assert_raises(ValueError, spectral_connectivity, data, mode='notamode')

    # test invalid fmin fmax settings
    assert_raises(ValueError, spectral_connectivity, data, fmin=10,
                  fmax=10 + 0.5 * (sfreq / float(n_times)))
    assert_raises(ValueError, spectral_connectivity, data, fmin=10, fmax=5)
    assert_raises(ValueError, spectral_connectivity, data, fmin=(0, 11),
                  fmax=(5, 10))
    assert_raises(ValueError, spectral_connectivity, data, fmin=(11, ),
                  fmax=(12, 15))

    methods = [
        'coh', 'imcoh', 'cohy', 'plv', 'ppc', 'pli', 'pli2_unbiased', 'wpli',
        'wpli2_debiased', 'coh'
    ]

    modes = ['multitaper', 'fourier', 'cwt_morlet']

    # define some frequencies for cwt
    cwt_frequencies = np.arange(3, 24.5, 1)

    for mode in modes:
        for method in methods:
            if method == 'coh' and mode == 'multitaper':
                # only check adaptive estimation for coh to reduce test time
                check_adaptive = [False, True]
            else:
                check_adaptive = [False]

            if method == 'coh' and mode == 'cwt_morlet':
                # so we also test using an array for num cycles
                cwt_n_cycles = 7. * np.ones(len(cwt_frequencies))
            else:
                cwt_n_cycles = 7.

            for adaptive in check_adaptive:

                if adaptive:
                    mt_bandwidth = 1.
                else:
                    mt_bandwidth = None

                con, freqs, times, n, _ = spectral_connectivity(
                    data, method=method, mode=mode, indices=None, sfreq=sfreq,
                    mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth, cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(n == n_epochs)
                assert_array_almost_equal(times_data, times)

                if mode == 'multitaper':
                    upper_t = 0.95
                    lower_t = 0.5
                else:
                    # other estimates have higher variance
                    upper_t = 0.8
                    lower_t = 0.75

                # test the simulated signal
                if method == 'coh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # we see something for zero-lag
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] > upper_t))

                    if mode != 'cwt_morlet':
                        idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                        assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                        assert_true(np.all(con[1, 0, idx[1]:] < lower_t))
                elif method == 'cohy':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(
                        np.all(np.imag(con[1, 0, idx[0]:idx[1]]) < lower_t))
                    # we see something for zero-lag
                    assert_true(
                        np.all(np.abs(con[1, 0, idx[0]:idx[1]]) > upper_t))

                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    if mode != 'cwt_morlet':
                        assert_true(
                            np.all(np.abs(con[1, 0, :idx[0]]) < lower_t))
                        assert_true(
                            np.all(np.abs(con[1, 0, idx[1]:]) < lower_t))
                elif method == 'imcoh':
                    idx = np.searchsorted(freqs, (fstart + 1, fend - 1))
                    # imaginary coh will be zero
                    assert_true(np.all(con[1, 0, idx[0]:idx[1]] < lower_t))
                    idx = np.searchsorted(freqs, (fstart - 1, fend + 1))
                    assert_true(np.all(con[1, 0, :idx[0]] < lower_t))
                    assert_true(np.all(con[1, 0, idx[1]:] < lower_t))

                # compute same connections using indices and 2 jobs,
                # also add a second method
                indices = tril_indices(n_signals, -1)

                test_methods = (method, _CohEst)
                stc_data = _stc_gen(data, sfreq, tmin)
                con2, freqs2, times2, n2, _ = spectral_connectivity(
                    stc_data, method=test_methods, mode=mode, indices=indices,
                    sfreq=sfreq, mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth, tmin=tmin, tmax=tmax,
                    cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles, n_jobs=2)

                # one result per requested method
                assert_true(isinstance(con2, list))
                assert_true(len(con2) == 2)

                if method == 'coh':
                    # the _CohEst estimator is expected to match built-in coh
                    assert_array_almost_equal(con2[0], con2[1])

                con2 = con2[0]  # only keep the first method

                # we get the same result for the probed connections
                assert_array_almost_equal(freqs, freqs2)
                assert_array_almost_equal(con[indices], con2)
                assert_true(n == n2)
                assert_array_almost_equal(times_data, times2)

                # compute same connections for two bands, fskip=1, and f. avg.
                fmin = (5., 15.)
                fmax = (15., 30.)
                con3, freqs3, times3, n3, _ = spectral_connectivity(
                    data, method=method, mode=mode, indices=indices,
                    sfreq=sfreq, fmin=fmin, fmax=fmax, fskip=1, faverage=True,
                    mt_adaptive=adaptive, mt_low_bias=True,
                    mt_bandwidth=mt_bandwidth, cwt_frequencies=cwt_frequencies,
                    cwt_n_cycles=cwt_n_cycles)

                assert_true(isinstance(freqs3, list))
                assert_true(len(freqs3) == len(fmin))
                for i in range(len(freqs3)):
                    assert_true(
                        np.all((freqs3[i] >= fmin[i]) &
                               (freqs3[i] <= fmax[i])))

                # average con2 "manually" and we get the same result
                for i in range(len(freqs3)):
                    freq_idx = np.searchsorted(freqs2, freqs3[i])
                    con2_avg = np.mean(con2[:, freq_idx], axis=1)
                    assert_array_almost_equal(con2_avg, con3[:, i])
def plot_connectivity_circle_cvu(con, nodes_numberless, indices=None,
                                 n_lines=10000, node_colors=None,
                                 colormap='YlOrRd', fig=None, reqrois=None,
                                 suppress_extra_rois=False, node_angles=None,
                                 node_width=None, facecolor='black',
                                 textcolor='white', node_edgecolor='black',
                                 linewidth=1.5, vmin=None, vmax=None,
                                 colorbar=False, title=None,
                                 fontsize_names='auto',
                                 bilateral_symmetry=True):
    """Visualize connectivity as a circular graph.

    Note: This code is originally taken from public open-source
    examples in matplotlib by Nicolas P. Rougier. It was adapted for use in
    MNE python, primarily by Martin Luessi, but also by all the other
    contributors to MNE python.

    There are some differences between the current version and the MNE python
    version. Most importantly, the current version offers less flexibility
    of the layout of the plot and has algorithms to determine this layout
    automatically given the ordering of the CVU dataset. Each hemisphere
    takes up roughly half the space and the left hemisphere is always on the
    left side of the plot. Label names that would be drawn too close together
    are then pruned so that the remaining labels are readable.

    Note that the suppression of label names can be overwritten in the GUI
    although it is quite effortful; typically it is recommended to do image
    postprocessing instead.

    Parameters
    ----------
    con : array
        Connectivity scores. Can be a square matrix, or a 1D array. If a 1D
        array is provided, "indices" has to be used to define the connection
        indices.
    nodes_numberless : list of str
        Node names. The order corresponds to the order in con. The first
        character of each name identifies its hemisphere ('l' = left).
        Digits are stripped from both ends of each name before it is drawn.
    indices : tuple of arrays | array, shape (2, n_con) | None
        Two arrays with indices of connections for which the connection
        strengths are defined in con. Only needed if con is a 1D array.
    n_lines : int | None
        If not None, only the n_lines strongest connections (strength=abs(con))
        are drawn.
    node_colors : list of tuples | list of str | None
        List with the color to use for each node. If fewer colors than nodes
        are provided, the colors will be repeated. Any color supported by
        matplotlib can be used, e.g., RGBA tuples, named colors. If None, all
        nodes keep the default light-gray face color.
    colormap : str | matplotlib colormap | CustomColormap
        Colormap to use for coloring the connections.
    fig : instance of pyplot.Figure | None
        Figure to draw into. If None, a new 5x5 figure is created.
    reqrois : list of str | None
        Node names handed to the label-pruning helper (presumably labels that
        must be kept). With suppress_extra_rois, only these labels are drawn.
        None means an empty list.
    suppress_extra_rois : bool
        If True and reqrois is non-empty, draw only the labels in reqrois.
    node_angles : array, shape=(len(nodes_numberless,)) | None
        Array with node positions in degrees, in plotted (reordered) node
        order. If None, the nodes are equally spaced on the circle. See
        mne.viz.circular_layout.
    node_width : float | None
        Width of each node in degrees. If None, "360. / len(nodes_numberless)"
        is used.
    facecolor : str
        Color to use for background. See matplotlib.colors.
    textcolor : str
        Color to use for text. See matplotlib.colors.
    node_edgecolor : str
        Color to use for lines around nodes. See matplotlib.colors.
    linewidth : float
        Line width to use for connections.
    vmin : float | None
        Minimum value for colormap. If None, it is determined automatically.
    vmax : float | None
        Maximum value for colormap. If None, it is determined automatically.
    colorbar : bool
        Display a colorbar or not.
    title : str
        The figure title.
    fontsize_names : int | str
        The fontsize for the node labels. If 'auto', the program attempts to
        determine a reasonable size. 'auto' is the default value.
    bilateral_symmetry : bool
        If True, the nodes of the second hemisphere are laid out in reverse
        order, so the circle is mirror-symmetric. Either way, the left
        hemisphere is placed first.

    Returns
    -------
    fig : instance of pyplot.Figure
        The figure handle.
    node_angles : array
        Angle (radians, rotated by pi / 2) at which each node was drawn, in
        plotted node order.
    """
    # avoid a shared mutable default argument
    if reqrois is None:
        reqrois = []

    n_nodes = len(nodes_numberless)

    # handle 1D and 2D connectivity information. This is done before the
    # hemisphere reordering so that indices built for a square matrix are
    # remapped together with the node names and colors.
    if con.ndim == 1:
        if indices is None:
            raise ValueError('indices has to be provided if con.ndim == 1')
        # we use 1D indices
    elif con.ndim == 2:
        if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
            raise ValueError('con has to be 1D or a square matrix')
        # we use the lower-triangular part
        indices = tril_indices(n_nodes, -1)
        con = con[indices]
    else:
        raise ValueError('con has to be 1D or a square matrix')
    # shape (2, n_con): row 0 holds one end of each connection, row 1 the
    # other (also accepts the tuple-of-arrays form)
    indices = np.asarray(indices)

    # expand node_colors to exactly one color per node (repeating as
    # documented) so it can be permuted together with the node names
    if node_colors is not None:
        node_colors = list(node_colors)
        if node_colors:
            node_colors = [node_colors[k % len(node_colors)]
                           for k in range(n_nodes)]

    # Hemisphere layout: the left hemisphere ('l' prefix) goes first; with
    # bilateral_symmetry the second hemisphere is listed in reverse so the
    # circle is mirror-symmetric.
    start_hemi = 'l'
    first_hemi = nodes_numberless[0][0]
    # first position whose hemisphere differs from the first node's
    # (n_nodes if every node belongs to the same hemisphere)
    hemi_pivot = next((i for i, name in enumerate(nodes_numberless)
                       if name[0] != first_hemi), n_nodes)
    first_block = list(range(hemi_pivot))
    second_block = list(range(hemi_pivot, n_nodes))
    if bilateral_symmetry:
        if start_hemi == first_hemi:
            order = first_block + second_block[::-1]
        else:
            order = second_block + first_block[::-1]
    elif start_hemi != first_hemi:
        # if bilateral symmetry is turned off, still put the left
        # hemisphere on the left side
        order = second_block + first_block
    else:
        order = None  # already in the desired order

    if order is not None:
        # order[new_position] = old_position
        nodes_numberless = [nodes_numberless[k] for k in order]
        if node_colors:
            node_colors = [node_colors[k] for k in order]
        if indices.size > 0:
            # new_position[old_position] = new_position
            new_position = np.empty(n_nodes, dtype=int)
            new_position[order] = np.arange(n_nodes)
            indices = new_position[indices]

    if node_angles is not None:
        if len(node_angles) != n_nodes:
            raise ValueError('node_angles has to be the same length '
                             'as nodes_numberless')
        # convert it to radians
        node_angles = np.asarray(node_angles) * np.pi / 180
    else:
        # uniform layout on unit circle
        node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)
    # rotate by 90 degrees; the label angles below get the same offset
    node_angles = node_angles + np.pi / 2

    if node_width is None:
        node_width = 2 * np.pi / n_nodes
    else:
        node_width = node_width * np.pi / 180

    # get the colormap
    if isinstance(colormap, CustomColormap):
        colormap = colormap._get__pl()
    elif isinstance(colormap, basestring):
        colormap = pl.get_cmap(colormap)

    # Make figure background the same colors as axes
    if fig is None:
        fig = pl.figure(figsize=(5, 5), facecolor=facecolor)
    else:
        fig = pl.figure(num=fig.number)

    # Use a polar axes
    # NOTE(review): 'axisbg' was deprecated in matplotlib 2.0 in favour of
    # 'facecolor' and later removed; kept for the matplotlib this file targets
    axes = pl.subplot(111, polar=True, axisbg=facecolor)

    # No ticks, we'll put our own
    pl.xticks([])
    pl.yticks([])

    # Set y axes limit
    pl.ylim(0, 10)

    axes.spines['polar'].set_visible(False)

    # Draw lines between connected nodes, only draw the strongest connections
    if n_lines is not None and len(con) > n_lines:
        con_thresh = np.sort(np.abs(con).ravel())[-n_lines]
    else:
        con_thresh = 0.

    # keep only the connections we are drawing; no sorting happens here,
    # connections are drawn in the order they were given
    con_abs = np.abs(con)
    con_draw_idx = np.where(con_abs >= con_thresh)[0]

    con = con[con_draw_idx]
    con_abs = con_abs[con_draw_idx]
    indices = [ind[con_draw_idx] for ind in indices]

    # Get vmin vmax for color scaling
    if np.size(con) > 0:
        if vmin is None:
            vmin = np.min(con[np.abs(con) >= con_thresh])
        if vmax is None:
            vmax = np.max(con)
        vrange = vmax - vmin

    # We want to add some "noise" to the start and end position of the
    # edges: We modulate the noise with the number of connections of the
    # node and the connection strength, such that the strongest connections
    # are closer to the node center
    nodes_n_con = np.zeros((n_nodes), dtype=int)
    for i, j in zip(indices[0], indices[1]):
        nodes_n_con[i] += 1
        nodes_n_con[j] += 1

    # initialize random number generator so plot is reproducible
    rng = np.random.RandomState(seed=0)

    n_con = len(indices[0])
    noise_max = 0.25 * node_width
    start_noise = rng.uniform(-noise_max, noise_max, n_con)
    end_noise = rng.uniform(-noise_max, noise_max, n_con)

    nodes_n_con_seen = np.zeros_like(nodes_n_con)
    for i, (start, end) in enumerate(zip(indices[0], indices[1])):
        nodes_n_con_seen[start] += 1
        nodes_n_con_seen[end] += 1

        # float() avoids integer (floor) division under Python 2, which would
        # otherwise zero out all of the noise
        start_noise[i] *= ((nodes_n_con[start] - nodes_n_con_seen[start]) /
                           float(nodes_n_con[start]))
        end_noise[i] *= ((nodes_n_con[end] - nodes_n_con_seen[end]) /
                         float(nodes_n_con[end]))

    # scale connectivity for colormap (vmin<=>0, vmax<=>1)
    if np.size(con) > 0:
        con_val_scaled = (con - vmin) / vrange

    # Finally, we draw the connections
    for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
        # Start point
        t0, r0 = node_angles[i], 7

        # End point
        t1, r1 = node_angles[j], 7

        # Some noise in start and end point
        t0 += start_noise[pos]
        t1 += end_noise[pos]

        verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
        codes = [
            m_path.Path.MOVETO, m_path.Path.CURVE4, m_path.Path.CURVE4,
            m_path.Path.LINETO
        ]
        path = m_path.Path(verts, codes)

        color = colormap(con_val_scaled[pos])

        # Actual line
        patch = m_patches.PathPatch(path, fill=False, edgecolor=color,
                                    linewidth=linewidth, alpha=1.)
        axes.add_patch(patch)

    # Draw ring with colored nodes
    radii = np.ones(n_nodes) - .2
    bars = axes.bar(node_angles, radii, width=node_width, bottom=7.2,
                    edgecolor=node_edgecolor, linewidth=0, facecolor='.9',
                    align='center', zorder=10)

    # node_colors already holds one color per node (expanded above); with
    # node_colors=None the default face color is kept
    if node_colors:
        for bar, color in zip(bars, node_colors):
            bar.set_facecolor(color)

    # Draw node labels
    # basic idea -- check for "too close" pairs. too close is pi/50
    # remove smallest "too close" pairs. if multiple tied within a segment,
    # remove pairs at equal spacing.
    # calculate each segment individually and find the extent of the segment.

    # TODO this parameter, too_close, could be modified and adjusted for
    # a variety of sizes if ever the circle could be panned (or if it were
    # merely made bigger). determining the proper value is a matter of
    # userspace testing
    too_close = np.pi / 50

    # get angles for text placement
    text_angles = get_labels_avg_idx(nodes_numberless, n_nodes, frac=1,
                                     pad=np.pi / 400)

    segments = get_tooclose_segments(text_angles, too_close, reqrois)

    for segment in segments:
        prune_segment(text_angles, segment, too_close)

    if suppress_extra_rois and len(reqrois) > 0:
        # iterate over a snapshot: deleting while iterating a live key view
        # raises RuntimeError on Python 3
        for name in list(text_angles.keys()):
            if name not in reqrois:
                del text_angles[name]

        if fontsize_names == 'auto':
            fontsize_names = 10

    # TODO segments with many guaranteed ROIs are potentially spatially
    # skewed; this is probably not worth fixing

    if fontsize_names == 'auto':
        fontsize_names = 8

    for name in text_angles:
        # same pi / 2 rotation as applied to node_angles
        angle_rad = text_angles[name] + np.pi / 2
        angle_deg = 180 * angle_rad / np.pi
        if angle_deg >= 270 or angle_deg < 90:
            ha = 'left'
        else:
            # Flip the label, so text is always upright
            angle_deg += 180
            ha = 'right'

        name_nonum = name.strip('1234567890')
        axes.text(angle_rad, 8.2, name_nonum, size=fontsize_names,
                  rotation=angle_deg, rotation_mode='anchor',
                  horizontalalignment=ha, verticalalignment='center',
                  color=textcolor)

    if title is not None:
        pl.subplots_adjust(left=0.2, bottom=0.2, right=0.8, top=0.75)
        pl.figtext(0.03, 0.95, title, color=textcolor, fontsize=14)
    else:
        pl.subplots_adjust(left=0.2, bottom=0.2, right=0.8, top=0.8)

    if colorbar:
        # Normalize (capitalised) exists in all matplotlib versions; the
        # lowercase alias 'normalize' was removed
        sm = pl.cm.ScalarMappable(cmap=colormap,
                                  norm=pl.Normalize(vmin=vmin, vmax=vmax))
        sm.set_array(np.linspace(vmin, vmax))
        ax = fig.add_axes([.92, 0.03, .015, .25])
        cb = fig.colorbar(sm, cax=ax)
        cb_yticks = pl.getp(cb.ax.axes, 'yticklabels')
        pl.setp(cb_yticks, color=textcolor)

    return fig, node_angles