# -----------------------------------------------------------------------------

# Apply RESS to the data at the target frequency. With ``return_maps=True``
# this call unpacks into three values: the RESS component time series
# (``out``), the maps used below to project back to the sensors (``maps``),
# and a third value that is discarded here.
out, maps, _ = ress.RESS(data, sfreq=sfreq, peak_freq=target, return_maps=True)

# Compute the PSD of the RESS output with Welch's method: 250-sample Hamming
# windows overlapping by 125 samples (50%), taken along axis 0.
# NOTE(review): ``squeeze(1)`` requires axis 1 of ``out`` to have length 1
# (presumably a single kept component) — confirm against RESS's default n_keep.
nfft = 250
df = sfreq / nfft  # frequency resolution (spacing between PSD bins)
bins, psd = ss.welch(out.squeeze(1),
                     sfreq,
                     window="hamming",
                     nperseg=nfft,
                     noverlap=125,
                     axis=0)
psd = psd.mean(axis=1, keepdims=True)  # average over trials
snr = snr_spectrum(psd, bins, skipbins=2, n_avg=2)

# Plot the SNR spectrum and highlight the target bin in red. Dotted guides mark
# SNR = 1 and the target frequency.
# NOTE(review): ``bins == target`` is an exact float comparison; nothing is
# highlighted unless ``target`` falls exactly on the Welch frequency grid
# (spacing ``df``).
f, ax = plt.subplots(1)
ax.plot(bins, snr, 'o', label='SNR')
ax.plot(bins[bins == target], snr[bins == target], 'ro', label='Target SNR')
ax.axhline(1, ls=':', c='grey', zorder=0)
ax.axvline(target, ls=':', c='grey', zorder=0)
ax.set_ylabel('SNR (a.u.)')
ax.set_xlabel('Frequency (Hz)')
ax.set_xlim([0, 40])

###############################################################################
# Project components back into sensor space to see the effects of RESS on the
# average SSVEP.

proj = matmul3d(out, maps)
def test_ress(target, n_trials, show=False):
    """Test RESS on synthetic data with a known target frequency.

    Runs RESS on data generated by ``create_data`` and checks two things
    about the SNR spectrum of the extracted component. It must exceed 10 at
    ``target``, and it must stay below 2 at least 2 Hz away from it. Then
    checks that RESS accepts ``n_keep`` of 1, 2 and -1. Finally, checks that
    projecting the kept component back through the returned maps gives an
    array shaped like ``data``.

    Parameters
    ----------
    target : float
        Simulated response frequency, also passed to RESS as ``peak_freq``.
    n_trials : int
        Number of simulated trials.
    show : bool
        If True, plot the SNR/PSD spectra and the per-channel averages
        before and after RESS.
    """
    sfreq = 250
    data, source = create_data(n_times=1000,
                               n_trials=n_trials,
                               freq=target,
                               sfreq=sfreq,
                               show=False)

    out = ress.RESS(data, sfreq=sfreq, peak_freq=target)

    # Welch PSD: non-overlapping 500-sample boxcar segments along axis 0.
    nfft = 500
    bins, psd = ss.welch(out.squeeze(1),
                         sfreq,
                         window="boxcar",
                         nperseg=nfft,
                         noverlap=0,
                         axis=0,
                         average='mean')

    psd = psd.mean(axis=-1, keepdims=True)  # average over trials
    # Add a floor of 5% of the peak power before computing the SNR,
    # presumably so near-empty bins do not produce unstable ratios.
    snr = snr_spectrum(psd + psd.max() / 20, bins, skipbins=1, n_avg=2)
    if show:
        fig, ax = plt.subplots(2)
        ax[0].plot(bins, snr, ':o')
        ax[0].axhline(1, ls=':', c='grey', zorder=0)
        ax[0].axvline(target, ls=':', c='grey', zorder=0)
        ax[0].set_ylabel('SNR (a.u.)')
        ax[0].set_xlabel('Frequency (Hz)')
        ax[0].set_xlim([0, 40])
        ax[0].set_ylim([0, 10])
        ax[1].plot(bins, psd)
        ax[1].axvline(target, ls=':', c='grey', zorder=0)
        ax[1].set_ylabel('PSD')
        ax[1].set_xlabel('Frequency (Hz)')
        ax[1].set_xlim([0, 40])

    # The target must land exactly on a PSD bin; fail loudly instead of
    # evaluating the truth value of an empty array.
    on_target = bins == target
    assert on_target.any(), "target frequency is not on the PSD frequency grid"
    assert (snr[on_target] > 10).all()
    assert (snr[(bins <= target - 2) | (bins >= target + 2)] < 2).all()

    # test multiple components
    out, maps = ress.RESS(data,
                          sfreq=sfreq,
                          peak_freq=target,
                          n_keep=1,
                          return_maps=True)
    _ = ress.RESS(data, sfreq=sfreq, peak_freq=target, n_keep=2)
    _ = ress.RESS(data, sfreq=sfreq, peak_freq=target, n_keep=-1)

    proj = matmul3d(out, maps.T)
    assert proj.shape == data.shape

    if show:
        n_chans = data.shape[1]
        # squeeze=False keeps ``ax`` 2-D even when there is a single channel.
        fig, ax = plt.subplots(n_chans, 2, sharey='col', squeeze=False)
        for c in range(n_chans):
            ax[c, 0].plot(data[:, c].mean(-1), lw=.5, label='data')
            ax[c, 1].plot(proj[:, c].mean(-1), lw=.5, label='projection')
            # Only the bottom row keeps its x ticks.
            if c < n_chans - 1:
                ax[c, 0].set_xticks([])
                ax[c, 1].set_xticks([])

        ax[0, 0].set_title('Before')
        ax[0, 1].set_title('After')
        plt.legend()
        plt.show()
# Example #3
# 0
def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False):
    """Test RESS with configurable peak and neighbour filter settings.

    Simulates ``n_trials`` trials of ``n_chans`` channels containing a
    response at ``target``. Checks that the single RESS component has an SNR
    above 10 at ``target`` and below 2 at least 2 Hz away from it. Also
    checks that the component projected back through ``fromress`` has shape
    ``(n_times, n_chans, n_trials)``, and smoke-tests ``n_keep=2`` and
    ``n_keep=-1``.

    Parameters
    ----------
    target : float
        Simulated response frequency, also passed to RESS as ``peak_freq``.
    n_trials : int
        Number of simulated trials.
    peak_width, neig_width, neig_freq
        Forwarded to RESS. ``peak_width`` also sets the Welch segment length,
        ``500 / (2 * peak_width)`` samples truncated to an integer.
    show : bool
        If True, plot the spectra, the per-channel averages before and after
        RESS, and the mixing/unmixing matrices.
    """
    sfreq = 250
    n_keep = 1
    n_chans = 10
    n_times = 1000
    data, source = create_data(n_times=n_times,
                               n_trials=n_trials,
                               n_chans=n_chans,
                               freq=target,
                               sfreq=sfreq,
                               show=False)

    out = ress.RESS(data,
                    sfreq=sfreq,
                    peak_freq=target,
                    neig_freq=neig_freq,
                    peak_width=peak_width,
                    neig_width=neig_width,
                    n_keep=n_keep)

    # Welch PSD with non-overlapping boxcar segments along axis 0. ``welch``
    # documents ``nperseg`` as an int, so truncate explicitly rather than
    # passing the float quotient.
    nfft = 500
    nperseg = int(nfft / (peak_width * 2))
    bins, psd = ss.welch(out.squeeze(1),
                         sfreq,
                         window="boxcar",
                         nperseg=nperseg,
                         noverlap=0,
                         axis=0,
                         average='mean')

    psd = psd.mean(axis=-1, keepdims=True)  # average over trials
    # Add a floor of 5% of the peak power before computing the SNR,
    # presumably so near-empty bins do not produce unstable ratios.
    snr = snr_spectrum(psd + psd.max() / 20, bins, skipbins=1, n_avg=2)
    if show:
        fig, ax = plt.subplots(2)
        ax[0].plot(bins, snr, ':o')
        ax[0].axhline(1, ls=':', c='grey', zorder=0)
        ax[0].axvline(target, ls=':', c='grey', zorder=0)
        ax[0].set_ylabel('SNR (a.u.)')
        ax[0].set_xlabel('Frequency (Hz)')
        ax[0].set_xlim([0, 40])
        ax[0].set_ylim([0, 10])
        ax[1].plot(bins, psd)
        ax[1].axvline(target, ls=':', c='grey', zorder=0)
        ax[1].set_ylabel('PSD')
        ax[1].set_xlabel('Frequency (Hz)')
        ax[1].set_xlim([0, 40])

    # The target must land exactly on a PSD bin (depends on ``peak_width``
    # through ``nperseg``); fail loudly instead of evaluating the truth value
    # of an empty array.
    on_target = bins == target
    assert on_target.any(), "target frequency is not on the PSD frequency grid"
    assert (snr[on_target] > 10).all()
    assert (snr[(bins <= target - 2) | (bins >= target + 2)] < 2).all()

    # test multiple components
    out, fromress, _ = ress.RESS(data,
                                 sfreq=sfreq,
                                 peak_freq=target,
                                 neig_freq=neig_freq,
                                 peak_width=peak_width,
                                 neig_width=neig_width,
                                 n_keep=n_keep,
                                 return_maps=True)

    proj = matmul3d(out, fromress)
    assert proj.shape == (n_times, n_chans, n_trials)

    if show:
        n_plot = data.shape[1]
        # squeeze=False keeps ``ax`` 2-D even when there is a single channel.
        fig, ax = plt.subplots(n_plot, 2, sharey='col', squeeze=False)
        for c in range(n_plot):
            ax[c, 0].plot(data[:, c].mean(-1), lw=.5, label='data')
            ax[c, 1].plot(proj[:, c].mean(-1), lw=.5, label='projection')
            # Only the bottom row keeps its x ticks.
            if c < n_plot - 1:
                ax[c, 0].set_xticks([])
                ax[c, 1].set_xticks([])

        ax[0, 0].set_title('Before')
        ax[0, 1].set_title('After')
        plt.legend()

    # 2 comps
    _ = ress.RESS(data, sfreq=sfreq, peak_freq=target, n_keep=2)

    # All comps
    out, fromress, toress = ress.RESS(data,
                                      sfreq=sfreq,
                                      peak_freq=target,
                                      n_keep=-1,
                                      return_maps=True)
    toress_pinv = pinv(toress)  # computed once, used for plotting and report

    if show:
        # Inspect mixing/unmixing matrices. Use the largest magnitude across
        # all three matrices so the symmetric colour range [-_max, _max]
        # does not clip strongly negative weights.
        _max = max(np.abs(m).max() for m in (toress, fromress, toress_pinv))

        fig, ax = plt.subplots(3)
        ax[0].imshow(toress, label='toRESS')
        ax[0].set_title('toRESS')
        ax[1].imshow(fromress, label='fromRESS', vmin=-_max, vmax=_max)
        ax[1].set_title('fromRESS')
        ax[2].imshow(toress_pinv, vmin=-_max, vmax=_max)
        ax[2].set_title('toRESS$^{-1}$')
        plt.tight_layout()
        plt.show()

    # Diagnostic only: count the entries where fromRESS and pinv(toRESS)
    # differ by at least 0.1.
    print(np.sum(np.abs(toress_pinv - fromress) >= .1))