Beispiel #1
0
def plot_tfr(
    tfr,
    time_cutoff,
    vmin,
    vmax,
    tl,
    cluster_correct=False,
    threshold=0.05,
    plot_colorbar=False,
    ax=None,
    cmap=None,
    stat_cutoff=None,
):
    """Plot the axis-0 average of a time-frequency contrast as an image.

    The data are fetched with ``contrast_tfr.get_tfr(tfr, stat_cutoff)``,
    averaged over their first axis with ``np.nanmean`` and handed to ``pmi``,
    which draws on the current axes (``plt.gca()``).

    Parameters
    ----------
    tfr
        Passed unchanged to ``contrast_tfr.get_tfr``.
    time_cutoff
        x-axis limits of the panel; also used as ``stat_cutoff`` when that
        argument is ``None``.
    vmin, vmax
        Colour limits forwarded to ``pmi``; together with 0 they are the
        colorbar ticks.
    tl
        Unused; kept so that existing callers keep working.
    cluster_correct
        Falsy: no significance mask is drawn. Otherwise it is treated as a
        cache indexed by ``joblib.hash([times, freqs, X, threshold])`` whose
        entries are 4-tuples with the cluster p-values in third position.
        On a cache miss, or when the value cannot be indexed at all (for
        example ``True``), the statistics are computed with
        ``get_tfr_stats``.
    threshold
        p-value below which a time-frequency bin is included in the mask;
        also forwarded to ``get_tfr_stats``.
    plot_colorbar
        If true, add a colorbar for the image.
    ax
        Axes that receive the dashed vertical line at 0 and that are
        returned. Defaults to the current axes.
    cmap
        Colormap for the image and the mask; a blue-to-red custom map is
        built when ``None``.
    stat_cutoff
        Second argument of ``contrast_tfr.get_tfr``; defaults to
        ``time_cutoff``.

    Returns
    -------
    The axes object described under ``ax``.
    """
    from pymeg.contrast_tfr import get_tfr_stats

    # colorbar:
    from matplotlib.colors import LinearSegmentedColormap

    if cmap is None:
        cmap = LinearSegmentedColormap.from_list(
            "custom", ["blue", "lightblue", "lightgrey", "yellow", "red"],
            N=100)

    if stat_cutoff is None:
        stat_cutoff = time_cutoff

    # The signature advertises ``ax=None``; without this fallback the
    # ``ax.axvline`` call below fails with AttributeError.
    if ax is None:
        ax = plt.gca()

    # data:
    times, freqs, X = contrast_tfr.get_tfr(tfr, stat_cutoff)
    mask = None
    if cluster_correct:
        cache_key = joblib.hash([times, freqs, X, threshold])
        try:
            cluster_stats = cluster_correct[cache_key]
        except (KeyError, TypeError):
            # KeyError: not cached yet. TypeError: ``cluster_correct`` is a
            # plain flag such as ``True`` rather than a cache.
            cluster_stats = get_tfr_stats(times, freqs, X,
                                          threshold)[cache_key]
        _, _, cluster_p_values, _ = cluster_stats
        sig = cluster_p_values.reshape((X.shape[1], X.shape[2]))
        mask = sig < threshold

    cax = pmi(
        plt.gca(),
        np.nanmean(X, 0),
        times,
        yvals=freqs,
        yscale="linear",
        vmin=vmin,
        vmax=vmax,
        mask=mask,
        mask_alpha=1,
        mask_cmap=cmap,
        cmap=cmap,
    )
    plt.xlim(time_cutoff)
    # Pad by half a unit so the lowest and highest frequency rows are not
    # clipped at the axes border.
    plt.ylim([freqs.min() - 0.5, freqs.max() + 0.5])
    ax.axvline(0, ls="--", lw=0.75, color="black")
    if plot_colorbar:
        plt.colorbar(cax, ticks=[vmin, 0, vmax])
    return ax
Beispiel #2
0
def plot_epoch_pair(
    tfr_data,
    vmin=-25,
    vmax=25,
    cmap="RdBu_r",
    gs=None,
    stats=False,
    threshold=0.05,
    ylabel=None,
):
    """Draw the stimulus and response panels of one TFR side by side.

    For each of the two epochs the rows of ``tfr_data`` with a matching
    ``epoch`` value are selected, converted with ``get_tfr``, averaged over
    axis 0 with ``np.nanmean`` and drawn with ``pmi``. An epoch without rows
    gets an empty panel with no ticks.

    Parameters
    ----------
    tfr_data
        Object with a ``query`` method that accepts ``'epoch=="..."'`` and
        supports ``len`` (presumably a pandas DataFrame; confirm against
        callers).
    vmin, vmax
        Colour limits forwarded to ``pmi``.
    cmap
        Colormap for the image and the mask.
    gs
        Optional subplot spec to nest the 1x2 grid into. When ``None`` a
        fresh ``GridSpec(1, 2)`` is used.
    stats
        Falsy: no significance mask. Otherwise treated as a cache indexed by
        ``joblib.hash([times, freqs, tfr, threshold])`` whose entries are
        4-tuples with the cluster p-values in third position. On a cache
        miss, or when the value cannot be indexed (for example ``True``),
        ``get_tfr_stats`` is called.
    threshold
        p-value below which a bin is included in the mask.
    ylabel
        Optional y-axis label, applied to the stimulus panel only.

    Returns
    -------
    tuple
        ``(times, freqs, tfr)`` of the last epoch that had data, or
        ``(None, None, None)`` when neither epoch had any.
    """
    from matplotlib import gridspec
    import pylab as plt
    import joblib

    if gs is None:
        g = gridspec.GridSpec(1, 2)
    else:
        g = gridspec.GridSpecFromSubplotSpec(1,
                                             2,
                                             subplot_spec=gs,
                                             wspace=0.01,
                                             width_ratios=[1, 0.4])

    # (epoch, grid column, time window, x ticks, y ticks)
    panels = (
        ("stimulus", 0, (-0.35, 1.1), [0, 0.25, 0.5, 0.75, 1],
         [25, 50, 75, 100, 125]),
        ("response", 1, (-0.35, 0.1), [0], [1, 25, 50, 75, 100, 125]),
    )

    # All three must be bound up front: if both epochs are empty the loop
    # never assigns them and the final return would otherwise fail.
    times, freqs, tfr = None, None, None
    for epoch, col, time_cutoff, xticks, yticks in panels:
        plt.subplot(g[0, col])
        tdata = tfr_data.query('epoch=="%s"' % (epoch))
        if len(tdata) == 0:
            plt.yticks([], [""])
            plt.xticks([], [""])
            continue
        times, freqs, tfr = get_tfr(tdata, time_cutoff)

        mask = None
        if stats:
            cache_key = joblib.hash([times, freqs, tfr, threshold])
            try:
                cluster_stats = stats[cache_key]
            except (KeyError, TypeError):
                # KeyError: not cached yet. TypeError: ``stats`` is a plain
                # flag such as ``True`` rather than a cache.
                cluster_stats = get_tfr_stats(times, freqs, tfr,
                                              threshold)[cache_key]
            _, _, cluster_p_values, _ = cluster_stats
            sig = cluster_p_values.reshape((tfr.shape[1], tfr.shape[2]))
            mask = sig < threshold

        pmi(
            plt.gca(),
            np.nanmean(tfr, 0),
            times,
            yvals=freqs,
            yscale="linear",
            vmin=vmin,
            vmax=vmax,
            mask=mask,
            mask_alpha=1,
            mask_cmap=cmap,
            cmap=cmap,
        )
        if (ylabel is not None) and (epoch == "stimulus"):
            plt.ylabel(ylabel, labelpad=-2, fontdict={"fontsize": 4})

        # Tick marks are kept but their labels are blanked.
        plt.yticks(yticks, [""] * len(yticks))
        plt.xticks(xticks, [""] * len(xticks))

        plt.tick_params(direction="inout", length=2, zorder=100)
        plt.xlim(time_cutoff)
        plt.ylim([1, 147.5])
        plt.axvline(0, color="k", lw=1, zorder=5, alpha=0.5)
        if epoch == "stimulus":
            # Second marker at t = 1, drawn only in the stimulus panel.
            plt.axvline(1, color="k", lw=1, zorder=5, alpha=0.5)
    return times, freqs, tfr
Beispiel #3
0
def plot_mosaic(
    tfr_data,
    vmin=-25,
    vmax=25,
    cmap="RdBu_r",
    ncols=4,
    epoch="stimulus",
    stats=False,
    threshold=0.05,
):
    """Draw one small TFR panel per entry of ``atlas_glasser.areas``.

    Panels are laid out row by row on a grid with ``ncols`` columns. For
    every ``(name, area)`` pair the rows of ``tfr_data`` with
    ``cluster == area`` are converted with ``get_tfr``, averaged over axis 0
    with ``np.nanmean`` and drawn with ``pmi``. A ``ValueError`` raised while
    drawing a panel is printed and that panel is skipped. Finally a labelled
    all-zero legend panel is drawn in the first column of the last grid row.

    Parameters
    ----------
    tfr_data
        Object with a ``query`` method that accepts ``'cluster=="..."'``
        (presumably a pandas DataFrame; confirm against callers).
    vmin, vmax
        Colour limits forwarded to ``pmi``.
    cmap
        Colormap for the images and masks.
    ncols
        Number of grid columns.
    epoch
        ``"stimulus"`` selects the stimulus-locked time window, ticks and
        baseline shading; any other value selects the response-locked
        settings.
    stats
        Falsy: no significance mask. Otherwise treated as a cache indexed by
        ``joblib.hash([times, freqs, tfr, threshold])`` whose entries are
        4-tuples with the cluster p-values in third position. On a cache
        miss, or when the value cannot be indexed (for example ``True``),
        ``get_tfr_stats`` is called.
    threshold
        p-value below which a bin is included in the mask.

    Returns
    -------
    None
    """
    if epoch == "stimulus":
        time_cutoff = (-0.5, 1.35)
        xticks = [0, 0.25, 0.5, 0.75, 1]
        xticklabels = ["0\nStim on", "", ".5", "", "1\nStim off"]
        yticks = [25, 50, 75, 100, 125]
        yticklabels = ["25", "", "75", "", "125"]
        xmarker = [0, 1]
        baseline = (-0.25, 0)
    else:
        time_cutoff = (-1, 0.5)
        xticks = [-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5]
        xticklabels = ["-1", "", "-0.5", "", "0\nResponse", "", "0.5"]
        yticks = [1, 25, 50, 75, 100, 125]
        yticklabels = ["1", "25", "", "75", "", "125"]
        xmarker = [0, 1]
        baseline = None
    from matplotlib import gridspec
    import pylab as plt
    import seaborn as sns

    if stats:
        # Needed only for the stats cache key; imported once, not per panel.
        import joblib

    contrast_tfr.set_jw_style()
    sns.set_style("ticks")
    nrows = (len(atlas_glasser.areas) // ncols) + 1
    gs = gridspec.GridSpec(nrows, ncols)
    gs.update(wspace=0.01, hspace=0.01)

    # The legend panel below reuses the grid of the last successfully
    # plotted area; ``None`` marks "nothing plotted yet".
    times = freqs = tfr = None
    for i, (name, area) in enumerate(atlas_glasser.areas.items()):
        try:
            column = np.mod(i, ncols)
            row = i // ncols
            plt.subplot(gs[row, column])
            times, freqs, tfr = get_tfr(tfr_data.query('cluster=="%s"' % area),
                                        time_cutoff)
            mask = None

            if stats:
                cache_key = joblib.hash([times, freqs, tfr, threshold])
                try:
                    cluster_stats = stats[cache_key]
                except (KeyError, TypeError):
                    # KeyError: not cached yet. TypeError: ``stats`` is a
                    # plain flag such as ``True`` rather than a cache.
                    cluster_stats = get_tfr_stats(times, freqs, tfr,
                                                  threshold)[cache_key]
                _, _, cluster_p_values, _ = cluster_stats
                sig = cluster_p_values.reshape((tfr.shape[1], tfr.shape[2]))
                mask = sig < threshold
            pmi(
                plt.gca(),
                np.nanmean(tfr, 0),
                times,
                yvals=freqs,
                yscale="linear",
                vmin=vmin,
                vmax=vmax,
                mask=mask,
                mask_alpha=1,
                mask_cmap=cmap,
                cmap=cmap,
            )

            for xmark in xmarker:
                plt.axvline(xmark, color="k", lw=1, zorder=-1, alpha=0.5)

            # Tick marks are kept but their labels are blanked; only the
            # legend panel carries labels.
            plt.yticks(yticks, [""] * len(yticks))
            plt.xticks(xticks, [""] * len(xticks))
            set_title(name, times, freqs, plt.gca())
            plt.tick_params(direction="inout", length=2, zorder=100)
            plt.xlim(time_cutoff)
            plt.ylim([1, 147.5])
            plt.axhline(10, color="k", lw=1, alpha=0.5, linestyle="--")
        except ValueError as e:
            # Best effort: report the failing area and go on with the rest.
            print(name, area, e)
    plt.subplot(gs[nrows - 2, 0])

    sns.despine(left=True, bottom=True)

    if tfr is None:
        # No area produced data, so there is no time/frequency grid to
        # build the legend panel from.
        print("plot_mosaic: no area could be plotted; skipping legend panel")
        return

    plt.subplot(gs[nrows - 1, 0])

    # Legend panel: an all-zero image on the same grid as the last plotted
    # area, this time with tick labels and axis labels.
    pmi(
        plt.gca(),
        np.nanmean(tfr, 0) * 0,
        times,
        yvals=freqs,
        yscale="linear",
        vmin=vmin,
        vmax=vmax,
        mask=None,
        mask_alpha=1,
        mask_cmap=cmap,
        cmap=cmap,
    )
    plt.xticks(xticks, xticklabels)
    plt.yticks(yticks, yticklabels)
    for xmark in xmarker:
        plt.axvline(xmark, color="k", lw=1, zorder=-1, alpha=0.5)
    if baseline is not None:
        # Shade the baseline interval over the full frequency range.
        plt.fill_between(baseline,
                         y1=[1, 1],
                         y2=[150, 150],
                         color="k",
                         alpha=0.5)
    plt.tick_params(direction="in", length=3)
    plt.xlim(time_cutoff)
    plt.ylim([1, 147.5])
    plt.xlabel("time [s]")
    plt.ylabel("Freq [Hz]")
    sns.despine(ax=plt.gca())
Beispiel #4
0
def plot_tfr(tfr,
             time_cutoff,
             vmin,
             vmax,
             tl,
             cluster_correct=False,
             threshold=0.05,
             plot_colorbar=False,
             ax=None,
             cmap=None,
             stat_cutoff=None,
             aspect=None,
             cluster=None,
             contrast_name=None,
             time_lock=None):
    """Plot a TFR contrast (frequencies >= 4 only) and export its data.

    The data are fetched with ``contrast_tfr.get_tfr(tfr, stat_cutoff)``.
    When both ``contrast_name`` and ``cluster`` are given, the full array is
    written out through ``array_to_data_source_file``. The axis-0 average,
    restricted to frequencies >= 4, is then drawn with ``pmi`` on the
    current axes (``plt.gca()``).

    Parameters
    ----------
    tfr
        Passed unchanged to ``contrast_tfr.get_tfr``.
    time_cutoff
        x-axis limits of the panel, the window searched for the earliest
        significant time point, and the default for ``stat_cutoff``.
    vmin, vmax
        Colour limits forwarded to ``pmi``; together with 0 they are the
        colorbar ticks.
    tl
        Unused; kept so that existing callers keep working.
    cluster_correct
        Falsy: no significance mask is drawn. Otherwise it is treated as a
        cache indexed by ``joblib.hash([times, freqs, X, threshold])`` whose
        entries are 4-tuples with the cluster p-values in third position.
        On a cache miss, or when the value cannot be indexed at all (for
        example ``True``), the statistics are computed with
        ``get_tfr_stats``.
    threshold
        p-value below which a time-frequency bin is included in the mask.
    plot_colorbar
        If true, add a colorbar for the image.
    ax
        Axes that receive the dashed vertical lines at 0 and 1 and that are
        returned. Defaults to the current axes.
    cmap
        Colormap for the image and the mask; a blue-to-red custom map is
        built when ``None``.
    stat_cutoff
        Second argument of ``contrast_tfr.get_tfr``; defaults to
        ``time_cutoff``.
    aspect
        Forwarded to ``Axes.set_aspect`` on the current axes; ``None``
        leaves the aspect untouched.
    cluster, contrast_name, time_lock
        Identify the data-source entry: figure number and panel are derived
        from ``contrast_name`` and the entry name is
        ``cluster + str(time_lock)``. The export is skipped when
        ``cluster`` or ``contrast_name`` is ``None``.

    Returns
    -------
    tuple
        ``(ax, earliest_sig)`` where ``earliest_sig`` is the first time
        inside ``time_cutoff`` at which any frequency is masked as
        significant, or ``None`` if there is no mask or no such time.
    """
    from pymeg.contrast_tfr import get_tfr_stats

    # colorbar:
    from matplotlib.colors import LinearSegmentedColormap

    if cmap is None:
        cmap = LinearSegmentedColormap.from_list(
            "custom", ["blue", "lightblue", "lightgrey", "yellow", "red"],
            N=100)

    if stat_cutoff is None:
        stat_cutoff = time_cutoff

    # The signature advertises ``ax=None``; without this fallback the
    # ``ax.axvline`` calls below fail with AttributeError.
    if ax is None:
        ax = plt.gca()

    # data:
    times, freqs, X = contrast_tfr.get_tfr(tfr, stat_cutoff)

    # Save data to data source file. Both identifiers default to ``None``,
    # which can neither be searched with ``in`` nor concatenated, so the
    # export only runs when they are provided.
    if contrast_name is not None and cluster is not None:
        from conf_analysis.meg.figures import array_to_data_source_file
        if "choice" in contrast_name:
            fnr, panel = "S6", "A"
        else:
            fnr = 2
            panel = "A" if "all" in contrast_name else "B"
        array_to_data_source_file(
            fnr, panel, cluster + str(time_lock), X, {
                # NOTE(review): subject labels are hard-coded to 1..15;
                # confirm this matches X.shape[0].
                'dim_0_subjects': np.arange(1, 16),
                'dim_1_frequencies': freqs,
                'dim_2_time': times
            })

    mask = None
    if cluster_correct:
        cache_key = joblib.hash([times, freqs, X, threshold])
        try:
            cluster_stats = cluster_correct[cache_key]
        except (KeyError, TypeError):
            # KeyError: not cached yet. TypeError: ``cluster_correct`` is a
            # plain flag such as ``True`` rather than a cache.
            cluster_stats = get_tfr_stats(times, freqs, X,
                                          threshold)[cache_key]
        _, _, cluster_p_values, _ = cluster_stats
        sig = cluster_p_values.reshape((X.shape[1], X.shape[2]))
        mask = sig < threshold

    earliest_sig = None
    if mask is not None:
        # Indices of time points at which at least one frequency is masked,
        # restricted to the displayed window.
        significant = np.where(np.any(mask, 0).ravel())[0]
        in_window = [
            t for t in significant
            if time_cutoff[0] <= times[t] <= time_cutoff[1]
        ]
        if in_window:
            earliest_sig = times[in_window[0]]

    # Only frequencies >= 4 are displayed.
    freqs_idx = freqs >= 4
    Xb = np.nanmean(X, 0)[freqs_idx, :]
    freqsb = freqs[freqs_idx]
    # Without cluster correction there is no mask to slice.
    mask_b = mask[freqs_idx, :] if mask is not None else None
    cax = pmi(
        plt.gca(),
        Xb,
        times,
        yvals=freqsb,
        yscale="linear",
        vmin=vmin,
        vmax=vmax,
        mask=mask_b,
        mask_alpha=1,
        mask_cmap=cmap,
        cmap=cmap,
    )
    if aspect is not None:
        # ``set_aspect`` accepts 'auto', 'equal' or a number, not ``None``.
        plt.gca().set_aspect(aspect)
    plt.xlim(time_cutoff)
    # NOTE(review): the y-limits use the full frequency range although only
    # frequencies >= 4 are drawn; kept as in the original.
    plt.ylim([freqs.min() - 0.5, freqs.max() + 0.5])
    ax.axvline(0, ls="--", lw=0.75, color="black")
    ax.axvline(1, ls="--", lw=0.75, color="black")
    if plot_colorbar:
        plt.colorbar(cax, ticks=[vmin, 0, vmax])
    return ax, earliest_sig