def plot_tfr(
    tfr,
    time_cutoff,
    vmin,
    vmax,
    tl,
    cluster_correct=False,
    threshold=0.05,
    plot_colorbar=False,
    ax=None,
    cmap=None,
    stat_cutoff=None,
):
    """Plot the axis-0 mean of a time-frequency contrast as a masked image.

    NOTE(review): another ``plot_tfr`` with a longer signature appears later
    in this source; if both live at module level the later one replaces this
    definition at import time -- confirm which one callers actually get.

    Parameters
    ----------
    tfr :
        Passed unchanged to ``contrast_tfr.get_tfr``.
    time_cutoff :
        Two-element time window applied as the x-limits of the plot.
    vmin, vmax :
        Color limits; also used as the outer colorbar ticks.
    tl :
        Not used by this function; kept for call compatibility.
    cluster_correct :
        Falsy to draw without a significance mask. Otherwise an object that
        is indexed with a ``joblib.hash`` key and yields a 4-item sequence
        whose third item holds the cluster p-values. On ``KeyError`` the
        statistics are recomputed with ``get_tfr_stats``.
    threshold :
        P-values strictly below this value are marked in the mask.
    plot_colorbar :
        If true, add a colorbar with ticks at ``vmin``, ``0`` and ``vmax``.
    ax :
        Axes that receive the dashed vertical line at ``t = 0``. Defaults to
        the current axes. The image itself is always drawn on the current
        pyplot axes (``plt.gca()``), as before.
    cmap :
        Colormap; defaults to a blue-lightgrey-red ``LinearSegmentedColormap``.
    stat_cutoff :
        Time window handed to ``contrast_tfr.get_tfr``; defaults to
        ``time_cutoff``.

    Returns
    -------
    The axes the reference line was drawn on.
    """
    from pymeg.contrast_tfr import get_tfr_stats
    from matplotlib.colors import LinearSegmentedColormap

    # Previously ``ax=None`` crashed at ``ax.axvline``; use the current axes.
    if ax is None:
        ax = plt.gca()
    if cmap is None:
        cmap = LinearSegmentedColormap.from_list(
            "custom", ["blue", "lightblue", "lightgrey", "yellow", "red"], N=100
        )
    if stat_cutoff is None:
        stat_cutoff = time_cutoff

    # X is treated as a 3-D array: axis 0 is averaged, axes 1 and 2 give the
    # shape of the significance mask.
    times, freqs, X = contrast_tfr.get_tfr(tfr, stat_cutoff)

    mask = None
    if cluster_correct:
        stats_key = joblib.hash([times, freqs, X, threshold])
        try:
            _, _, cluster_p_values, _ = cluster_correct[stats_key]
        except KeyError:
            # No cached result for exactly this data/threshold: recompute.
            computed = get_tfr_stats(times, freqs, X, threshold)
            _, _, cluster_p_values, _ = computed[stats_key]
        sig = cluster_p_values.reshape((X.shape[1], X.shape[2]))
        mask = sig < threshold

    cax = pmi(
        plt.gca(),
        np.nanmean(X, 0),
        times,
        yvals=freqs,
        yscale="linear",
        vmin=vmin,
        vmax=vmax,
        mask=mask,
        mask_alpha=1,
        mask_cmap=cmap,
        cmap=cmap,
    )
    plt.xlim(time_cutoff)
    plt.ylim([freqs.min() - 0.5, freqs.max() + 0.5])
    ax.axvline(0, ls="--", lw=0.75, color="black")
    if plot_colorbar:
        plt.colorbar(cax, ticks=[vmin, 0, vmax])
    return ax
def plot_epoch_pair(
    tfr_data,
    vmin=-25,
    vmax=25,
    cmap="RdBu_r",
    gs=None,
    stats=False,
    threshold=0.05,
    ylabel=None,
):
    """Draw stimulus- and response-locked TFR panels side by side.

    Parameters
    ----------
    tfr_data :
        Object with a ``query`` method (pandas-style); it is filtered with
        ``epoch=="stimulus"`` and ``epoch=="response"``.
    vmin, vmax :
        Color limits for both panels.
    cmap :
        Colormap used for the image and for the masked region.
    gs :
        Optional subplot spec to nest the 1x2 grid in. Without it a fresh
        ``GridSpec(1, 2)`` is used.
    stats :
        Falsy to draw without a significance mask. Otherwise an object that
        is indexed with a ``joblib.hash`` key and yields a 4-item sequence
        whose third item holds the cluster p-values. On ``KeyError`` the
        statistics are recomputed with ``get_tfr_stats``.
    threshold :
        P-values strictly below this value are marked in the mask.
    ylabel :
        Optional y-axis label, applied to the stimulus panel only.

    Returns
    -------
    tuple
        ``(times, freqs, tfr)`` from the last epoch that contained data, or
        ``(None, None, None)`` when neither epoch did.
    """
    from matplotlib import gridspec
    import pylab as plt
    import joblib

    if gs is None:
        g = gridspec.GridSpec(1, 2)
    else:
        g = gridspec.GridSpecFromSubplotSpec(
            1, 2, subplot_spec=gs, wspace=0.01, width_ratios=[1, 0.4]
        )

    # All three must be bound up front: if both epochs are empty the loop
    # never assigns them. The original initialised ``freq`` (typo), so the
    # return statement raised UnboundLocalError in that case.
    times, freqs, tfr = None, None, None
    for epoch in ["stimulus", "response"]:
        row = 0
        if epoch == "stimulus":
            col = 0
            time_cutoff = (-0.35, 1.1)
            xticks = [0, 0.25, 0.5, 0.75, 1]
            yticks = [25, 50, 75, 100, 125]
        else:
            col = 1
            time_cutoff = (-0.35, 0.1)
            xticks = [0]
            yticks = [1, 25, 50, 75, 100, 125]

        plt.subplot(g[row, col])
        tdata = tfr_data.query('epoch=="%s"' % (epoch))
        if len(tdata) == 0:
            # Nothing to draw for this epoch: leave an empty, tick-less panel.
            plt.yticks([], [""])
            plt.xticks([], [""])
            continue

        times, freqs, tfr = get_tfr(tdata, time_cutoff)

        mask = None
        if stats:
            stats_key = joblib.hash([times, freqs, tfr, threshold])
            try:
                _, _, cluster_p_values, _ = stats[stats_key]
            except KeyError:
                # No cached result for exactly this data/threshold: recompute.
                computed = get_tfr_stats(times, freqs, tfr, threshold)
                _, _, cluster_p_values, _ = computed[stats_key]
            sig = cluster_p_values.reshape((tfr.shape[1], tfr.shape[2]))
            mask = sig < threshold

        pmi(
            plt.gca(),
            np.nanmean(tfr, 0),
            times,
            yvals=freqs,
            yscale="linear",
            vmin=vmin,
            vmax=vmax,
            mask=mask,
            mask_alpha=1,
            mask_cmap=cmap,
            cmap=cmap,
        )
        if (ylabel is not None) and (epoch == "stimulus"):
            plt.ylabel(ylabel, labelpad=-2, fontdict={"fontsize": 4})

        # Tick marks are kept but their labels are blanked.
        plt.yticks(yticks, [""] * len(yticks))
        plt.xticks(xticks, [""] * len(xticks))
        plt.tick_params(direction="inout", length=2, zorder=100)
        plt.xlim(time_cutoff)
        plt.ylim([1, 147.5])
        plt.axvline(0, color="k", lw=1, zorder=5, alpha=0.5)
        if epoch == "stimulus":
            plt.axvline(1, color="k", lw=1, zorder=5, alpha=0.5)
    return times, freqs, tfr
def plot_mosaic(
    tfr_data,
    vmin=-25,
    vmax=25,
    cmap="RdBu_r",
    ncols=4,
    epoch="stimulus",
    stats=False,
    threshold=0.05,
):
    """Draw one TFR panel per entry of ``atlas_glasser.areas`` plus a legend.

    Panels are laid out row-major on a grid with ``ncols`` columns. Data
    panels have blank tick labels; a final labelled panel is drawn at the
    first column of the last grid row to act as an axis legend.

    Parameters
    ----------
    tfr_data :
        Object with a ``query`` method (pandas-style); it is filtered with
        ``cluster=="<area>"`` for every area.
    vmin, vmax :
        Color limits for all panels.
    cmap :
        Colormap used for the image and for the masked region.
    ncols :
        Number of grid columns.
    epoch :
        ``"stimulus"`` selects the stimulus-locked window, ticks and a shaded
        baseline span; any other value selects the response-locked settings.
    stats :
        Falsy to draw without a significance mask. Otherwise an object that
        is indexed with a ``joblib.hash`` key and yields a 4-item sequence
        whose third item holds the cluster p-values. On ``KeyError`` the
        statistics are recomputed with ``get_tfr_stats``.
    threshold :
        P-values strictly below this value are marked in the mask.

    Returns
    -------
    None
    """
    if epoch == "stimulus":
        time_cutoff = (-0.5, 1.35)
        xticks = [0, 0.25, 0.5, 0.75, 1]
        xticklabels = ["0\nStim on", "", ".5", "", "1\nStim off"]
        yticks = [25, 50, 75, 100, 125]
        yticklabels = ["25", "", "75", "", "125"]
        xmarker = [0, 1]
        baseline = (-0.25, 0)
    else:
        time_cutoff = (-1, 0.5)
        xticks = [-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5]
        xticklabels = ["-1", "", "-0.5", "", "0\nResponse", "", "0.5"]
        yticks = [1, 25, 50, 75, 100, 125]
        yticklabels = ["1", "25", "", "75", "", "125"]
        xmarker = [0, 1]
        baseline = None

    from matplotlib import gridspec
    import pylab as plt
    import seaborn as sns

    if stats:
        # Only needed for the statistics lookup; imported once, not per area.
        import joblib

    contrast_tfr.set_jw_style()
    sns.set_style("ticks")

    nrows = (len(atlas_glasser.areas) // ncols) + 1
    gs = gridspec.GridSpec(nrows, ncols)
    gs.update(wspace=0.01, hspace=0.01)

    # The legend panel below reuses the data of the last area that loaded.
    # Bind the names first so that "no area loaded" is detectable instead of
    # surfacing as a NameError after the loop.
    times = freqs = tfr = None
    for i, (name, area) in enumerate(atlas_glasser.areas.items()):
        try:
            row, column = divmod(i, ncols)
            plt.subplot(gs[row, column])
            times, freqs, tfr = get_tfr(
                tfr_data.query('cluster=="%s"' % area), time_cutoff
            )

            mask = None
            if stats:
                stats_key = joblib.hash([times, freqs, tfr, threshold])
                try:
                    _, _, cluster_p_values, _ = stats[stats_key]
                except KeyError:
                    # No cached result for this data/threshold: recompute.
                    computed = get_tfr_stats(times, freqs, tfr, threshold)
                    _, _, cluster_p_values, _ = computed[stats_key]
                sig = cluster_p_values.reshape((tfr.shape[1], tfr.shape[2]))
                mask = sig < threshold

            pmi(
                plt.gca(),
                np.nanmean(tfr, 0),
                times,
                yvals=freqs,
                yscale="linear",
                vmin=vmin,
                vmax=vmax,
                mask=mask,
                mask_alpha=1,
                mask_cmap=cmap,
                cmap=cmap,
            )
            for xmark in xmarker:
                plt.axvline(xmark, color="k", lw=1, zorder=-1, alpha=0.5)
            # Tick marks are kept but their labels are blanked.
            plt.yticks(yticks, [""] * len(yticks))
            plt.xticks(xticks, [""] * len(xticks))
            set_title(name, times, freqs, plt.gca())
            plt.tick_params(direction="inout", length=2, zorder=100)
            plt.xlim(time_cutoff)
            plt.ylim([1, 147.5])
            plt.axhline(10, color="k", lw=1, alpha=0.5, linestyle="--")
        except ValueError as e:
            # Best effort: report the failing area and move on to the next.
            print(name, area, e)

    plt.subplot(gs[nrows - 2, 0])
    sns.despine(left=True, bottom=True)

    # Legend panel: an all-zero image carrying the labelled axes.
    plt.subplot(gs[nrows - 1, 0])
    if tfr is not None:
        pmi(
            plt.gca(),
            np.nanmean(tfr, 0) * 0,
            times,
            yvals=freqs,
            yscale="linear",
            vmin=vmin,
            vmax=vmax,
            mask=None,
            mask_alpha=1,
            mask_cmap=cmap,
            cmap=cmap,
        )
    plt.xticks(xticks, xticklabels)
    plt.yticks(yticks, yticklabels)
    for xmark in xmarker:
        plt.axvline(xmark, color="k", lw=1, zorder=-1, alpha=0.5)
    if baseline is not None:
        plt.fill_between(baseline, y1=[1, 1], y2=[150, 150], color="k", alpha=0.5)
    plt.tick_params(direction="in", length=3)
    plt.xlim(time_cutoff)
    plt.ylim([1, 147.5])
    plt.xlabel("time [s]")
    plt.ylabel("Freq [Hz]")
    sns.despine(ax=plt.gca())
def plot_tfr(tfr, time_cutoff, vmin, vmax, tl, cluster_correct=False,
             threshold=0.05, plot_colorbar=False, ax=None, cmap=None,
             stat_cutoff=None, aspect=None, cluster=None, contrast_name=None,
             time_lock=None):
    """Plot a TFR contrast (>= 4 Hz), export its data, report first effect.

    Parameters
    ----------
    tfr :
        Passed unchanged to ``contrast_tfr.get_tfr``.
    time_cutoff :
        Two-element time window: used as the x-limits and as the range in
        which the earliest significant time point is searched.
    vmin, vmax :
        Color limits; also used as the outer colorbar ticks.
    tl :
        Not used by this function; kept for call compatibility.
    cluster_correct :
        Falsy to draw without a significance mask. Otherwise an object that
        is indexed with a ``joblib.hash`` key and yields a 4-item sequence
        whose third item holds the cluster p-values. On ``KeyError`` the
        statistics are recomputed with ``get_tfr_stats``.
    threshold :
        P-values strictly below this value are marked in the mask.
    plot_colorbar :
        If true, add a colorbar with ticks at ``vmin``, ``0`` and ``vmax``.
    ax :
        Axes that receive the dashed vertical lines at ``t = 0`` and
        ``t = 1``. Defaults to the current axes. The image itself is always
        drawn on the current pyplot axes (``plt.gca()``), as before.
    cmap :
        Colormap; defaults to a blue-lightgrey-red ``LinearSegmentedColormap``.
    stat_cutoff :
        Time window handed to ``contrast_tfr.get_tfr``; defaults to
        ``time_cutoff``.
    aspect :
        Forwarded to ``Axes.set_aspect`` when given; ``None`` leaves the
        aspect untouched.
    cluster, contrast_name, time_lock :
        Identify the entry written via ``array_to_data_source_file``. The
        export is skipped when ``cluster`` or ``contrast_name`` is ``None``.

    Returns
    -------
    tuple
        ``(ax, earliest_sig)``. ``earliest_sig`` is the first time inside
        ``time_cutoff`` at which any frequency is below ``threshold``, or
        ``None`` if there is no mask or no such time point.
    """
    from pymeg.contrast_tfr import get_tfr_stats
    from matplotlib.colors import LinearSegmentedColormap

    # Previously ``ax=None`` crashed at ``ax.axvline``; use the current axes.
    if ax is None:
        ax = plt.gca()
    if cmap is None:
        cmap = LinearSegmentedColormap.from_list(
            "custom", ["blue", "lightblue", "lightgrey", "yellow", "red"], N=100
        )
    if stat_cutoff is None:
        stat_cutoff = time_cutoff

    # X is exported below as (subjects, frequencies, time).
    times, freqs, X = contrast_tfr.get_tfr(tfr, stat_cutoff)

    # Save the plotted data to the data source file. With the default
    # ``None`` identifiers the original raised TypeError here; skip instead.
    if contrast_name is not None and cluster is not None:
        from conf_analysis.meg.figures import array_to_data_source_file

        if "choice" in contrast_name:
            fnr = 'S6'
            panel = 'A'
        else:
            fnr = 2
            panel = 'A' if 'all' in contrast_name else 'B'
        array_to_data_source_file(
            fnr,
            panel,
            cluster + str(time_lock),
            X,
            {
                # NOTE(review): subject labels are hard-coded to 1..15
                # regardless of X.shape[0] -- confirm this always matches.
                'dim_0_subjects': np.arange(1, 16),
                'dim_1_frequencies': freqs,
                'dim_2_time': times,
            },
        )

    mask = None
    if cluster_correct:
        stats_key = joblib.hash([times, freqs, X, threshold])
        try:
            _, _, cluster_p_values, _ = cluster_correct[stats_key]
        except KeyError:
            # No cached result for exactly this data/threshold: recompute.
            computed = get_tfr_stats(times, freqs, X, threshold)
            _, _, cluster_p_values, _ = computed[stats_key]
        sig = cluster_p_values.reshape((X.shape[1], X.shape[2]))
        mask = sig < threshold

    earliest_sig = None
    if mask is not None:
        # Time indices at which at least one frequency is significant,
        # restricted to the displayed window.
        idt = np.where(np.any(mask, 0).ravel())[0]
        idt = [t for t in idt if time_cutoff[0] <= times[t] <= time_cutoff[1]]
        if len(idt) > 0:
            earliest_sig = times[idt[0]]

    # Only frequencies of 4 Hz and above are drawn.
    freqs_idx = freqs >= 4
    Xb = np.nanmean(X, 0)[freqs_idx, :]
    freqsb = freqs[freqs_idx]
    # The mask must be cut to the same rows; without statistics there is no
    # mask to index (the original subscripted ``None`` here).
    plot_mask = None if mask is None else mask[freqs_idx, :]

    cax = pmi(
        plt.gca(),
        Xb,
        times,
        yvals=freqsb,
        yscale="linear",
        vmin=vmin,
        vmax=vmax,
        mask=plot_mask,
        mask_alpha=1,
        mask_cmap=cmap,
        cmap=cmap,
    )
    # set_aspect accepts 'auto', 'equal' or a float -- not None.
    if aspect is not None:
        plt.gca().set_aspect(aspect)
    plt.xlim(time_cutoff)
    # Limits follow the full frequency range, not only the plotted rows.
    plt.ylim([freqs.min() - 0.5, freqs.max() + 0.5])
    ax.axvline(0, ls="--", lw=0.75, color="black")
    ax.axvline(1, ls="--", lw=0.75, color="black")
    if plot_colorbar:
        plt.colorbar(cax, ticks=[vmin, 0, vmax])
    return ax, earliest_sig