Exemplo n.º 1
0
from pynfb.protocols.ssd.topomap_selector_ica import ICADialog
from PyQt5 import QtGui, QtWidgets

# The Qt application object is created before the first figure is opened.
a = QtWidgets.QApplication([])

fig1, axes1 = plt.subplots(ncols=3)

# Only subject index 3 is visited, for day indices 2, 3 and 4.
for subj in range(3, 4):
    for day in range(2, 5):
        experiments = settings['subjects'][subj]
        experiment = experiments[day]
        reject = True
        with h5py.File('{}\\{}\\{}'.format(
                settings['dir'], experiment, 'experiment_data.h5')) as f:
            fs, channels, p_names = get_info(f, settings['drop_channels'])
            rejection, alpha, ica = load_rejections(f, reject_alpha=False)
            raw_before, raw_after = OrderedDict(), OrderedDict()
            for j, name in enumerate(p_names):
                # Protocols that are neither among the first three nor among
                # the last two are left untouched.
                if not (j < 3 or j > len(p_names) - 3):
                    continue
                # Datasets are addressed with a 1-based protocol number.
                x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs,
                            rejection if reject else None)
                if j < 3:
                    raw_before = add_data_simple(raw_before, name, x)
                else:
                    print(x.shape)
                    raw_after = add_data_simple(raw_after, name, x)

        xx = np.concatenate([
            raw_before['Opened'],
Exemplo n.º 2
0
# Pick the entry for the current day index; `experiments` and `day` are
# expected to be defined before this point (not visible in this snippet).
experiment = experiments[day]


def preproc(x, fs, rej=None):
    """Run the standard preprocessing chain on ``x``.

    ``x`` is passed through ``dc_blocker`` and then through ``fft_filter``
    with ``band=(0, 45)`` (presumably Hz, given that ``fs`` is handed to the
    filter -- confirm against ``fft_filter``).  When ``rej`` is given, the
    filtered data is additionally right-multiplied by it with ``np.dot``.

    :param x: data to preprocess
    :param fs: sampling-rate argument forwarded to ``fft_filter``
    :param rej: optional matrix applied as ``np.dot(filtered, rej)``;
        ``None`` skips this step
    :return: the preprocessed data
    """
    filtered = fft_filter(dc_blocker(x), fs, band=(0, 45))
    if rej is None:
        return filtered
    return np.dot(filtered, rej)


# With reject set to False, preproc below always receives None instead of the
# loaded rejection matrix.
reject = False
with h5py.File('{}\\{}\\{}'.format(settings['dir'], experiment,
                                   'experiment_data.h5')) as f:
    fs, channels, p_names = get_info(f, settings['drop_channels'])
    rejection, alpha, ica = load_rejections(f, reject_alpha=True)
    raw_before = OrderedDict()
    for j, name in enumerate(p_names):
        # Datasets are addressed with a 1-based protocol number.
        x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs,
                    rejection if reject else None)
        # add_data receives the protocol name and its 0-based index.
        raw_before = add_data(raw_before, name, x, j)

# plot raw data
ch_plot = ['C3', 'C4', 'P3', 'P4']  #, 'Pz', 'Fp1']
# One subplot row per entry of ch_plot, with shared x and y axes.
fig1, axes = plt.subplots(len(ch_plot), ncols=1, sharex=True, sharey=True)
print(axes)

# Collect, for every stored protocol, the magnitude of the Hilbert transform
# of the (4, 8)-filtered data.  NOTE(review): the original note here said
# "find median"; the median itself is presumably computed further down, which
# is not visible in this snippet.
x_all = []
for name, x in raw_before.items():
    x_all.append(np.abs(hilbert(fft_filter(x, fs, (4, 8)))))
Exemplo n.º 3
0
# Hard-coded selection: subject index 3 and the entry at index 2 of that
# subject's experiment list.
subj = 3
experiments = settings['subjects'][subj]
experiment = experiments[2]

def preproc(x, fs, rej=None):
    """Preprocess ``x`` and optionally project it through ``rej``.

    The data goes through ``dc_blocker`` first and ``fft_filter`` with
    ``band=(0, 45)`` second (presumably Hz -- confirm against
    ``fft_filter``).  If ``rej`` is not ``None`` the result is multiplied by
    it from the right using ``np.dot``.

    :param x: data to preprocess
    :param fs: sampling-rate argument forwarded to ``fft_filter``
    :param rej: optional matrix for ``np.dot(filtered, rej)``
    :return: the preprocessed data
    """
    without_dc = dc_blocker(x)
    band_limited = fft_filter(without_dc, fs, band=(0, 45))
    return band_limited if rej is None else np.dot(band_limited, rej)


# With reject set to False, preproc below always receives None instead of the
# loaded rejection matrix.
reject = False
with h5py.File('{}\\{}\\{}'.format(
        settings['dir'], experiment, 'experiment_data.h5')) as f:
    fs, channels, p_names = get_info(f, settings['drop_channels'])
    rejection, alpha, ica = load_rejections(f, reject_alpha=True)
    raw_before, raw_after = OrderedDict(), OrderedDict()
    for j, name in enumerate(p_names):
        # Protocols that are neither among the first five nor among the last
        # two are left untouched.
        if not (j < 5 or j > len(p_names) - 3):
            continue
        # Datasets are addressed with a 1-based protocol number.
        x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs,
                    rejection if reject else None)
        if j < 5:
            raw_before = add_data_simple(raw_before, name, x)
        else:
            print(x.shape)
            raw_after = add_data_simple(raw_after, name, x)



# plot raw data
ch_plot = ['C3', 'C4', 'P3', 'P4']  #, 'Pz', 'Fp1']
Exemplo n.º 4
0
            mask_copy = mask.astype(int).copy()
            lengths_buffer = []
            for j, y in enumerate(mask_copy):
                if y:
                    lengths_buffer.append(x[j])
                elif len(lengths_buffer) > 0:
                    lengths.append(np.mean(lengths_buffer))
                    lengths_buffer = []
            print(lengths)
            return np.array(lengths)

        with h5py.File('{}\\{}\\{}'.format(dir_, experiment,
                                           'experiment_data.h5')) as f:
            fs, channels, p_names = get_info(f, settings['drop_channels'])
            if reject:
                rejections = load_rejections(f, reject_alpha=True)[0]
            else:
                rejections = None
            spatial = f['protocol15/signals_stats/left/spatial_filter'][:]
            mu_band = f['protocol15/signals_stats/left/bandpass'][:]
            #mu_band = (12, 13)
            max_gap = 1 / min(mu_band) * 2
            min_sate_duration = max_gap * 2
            raw = OrderedDict()
            signal = OrderedDict()
            for j, name in enumerate(p_names):
                x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs,
                            rejections)
                raw = add_data(raw, name, x, j)
                signal = add_data(
                    signal, name,
Exemplo n.º 5
0
def plot_results(pilot_dir,
                 subj,
                 channel,
                 alpha_band=(9, 14),
                 theta_band=(3, 6),
                 drop_channels=None,
                 dc=False,
                 reject_alpha=True,
                 normalize_by='opened'):
    """Plot signals, normalized band powers and rejection topographies.

    One column of axes is added to a single figure for every experiment in
    ``subj``.  For each experiment the file ``experiment_data.h5`` located
    under ``pilot_dir`` / ``experiment`` (joined with backslashes) is read and
    the following is drawn:

    * first row: the (2, 45)-filtered data of every protocol, with the second
      value returned by ``get_protocol_power`` overlaid;
    * second row: mean and standard deviation of every protocol's power
      divided by the normalization constant, plus a regression line over the
      protocols whose name contains ``'FB'`` (skipped when there are none);
    * bottom row: topographic maps of the columns of ``top_ica`` and
      ``top_alpha`` as returned by ``load_rejections``.

    :param pilot_dir: directory that contains one sub-directory per experiment
    :param subj: sequence of experiment directory names; its position in the
        sequence is used as the "Day" number in the titles
    :param channel: name of the channel to analyse; looked up with
        ``channels.index`` (raises ``ValueError`` if it is missing)
    :param alpha_band: band forwarded to ``get_protocol_power`` for the
        plotted powers
    :param theta_band: band forwarded to ``get_protocol_power`` for the
        ``'FB'`` protocols; only used for the ``'beta'`` normalization
    :param drop_channels: channel labels to exclude; ``None`` means none
    :param dc: forwarded unchanged to ``get_protocol_power``
    :param reject_alpha: forwarded unchanged to ``load_rejections``
    :param normalize_by: ``'opened'`` divides by the mean power stored under
        the key ``'1. Opened'``; ``'beta'`` divides by the mean of the
        ``theta_band`` powers of the ``'FB'`` protocols (the option name does
        not match the band it uses); any other value disables normalization
        (constant 1) and prints a warning
    :return: the figure all axes were added to
    """
    drop_channels = drop_channels or []
    cm = get_colors()
    fg = plt.figure(figsize=(30, 6))
    for j_s, experiment in enumerate(subj):
        with h5py.File('{}\\{}\\{}'.format(pilot_dir, experiment,
                                           'experiment_data.h5')) as f:
            rejections, top_alpha, top_ica = load_rejections(
                f, reject_alpha=reject_alpha)
            fs, channels, p_names = get_info(f, drop_channels)
            ch = channels.index(channel)

            # collect powers
            powers = OrderedDict()
            raw = OrderedDict()
            alpha = OrderedDict()
            pow_theta = []
            for j, name in enumerate(p_names):
                power, alpha_x, x = get_protocol_power(f,
                                                       j,
                                                       fs,
                                                       rejections,
                                                       ch,
                                                       alpha_band,
                                                       dc=dc)
                if 'FB' in name:
                    pow_theta.append(
                        get_protocol_power(f,
                                           j,
                                           fs,
                                           rejections,
                                           ch,
                                           theta_band,
                                           dc=dc)[0].mean())
                powers = add_data(powers, name, power, j)
                raw = add_data(raw, name, x, j)
                alpha = add_data(alpha, name, alpha_x, j)

            # plot rejections (bottom row of a 4-row grid)
            # NOTE: both loops rebind `fs` and `channels` from the stream
            # info; the time axis below uses whichever `fs` was bound last.
            n_tops = top_ica.shape[1] + top_alpha.shape[1]
            for j_t in range(top_ica.shape[1]):
                ax = fg.add_subplot(
                    4, n_tops * len(subj),
                    n_tops * len(subj) * 3 + n_tops * j_s + j_t + 1)
                ax.set_xlabel('ICA{}'.format(j_t + 1))
                labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                channels = [
                    label for label in labels if label not in drop_channels
                ]
                pos = ch_names_to_2d_pos(channels)
                plot_topomap(data=top_ica[:, j_t],
                             pos=pos,
                             axes=ax,
                             show=False)
            for j_t in range(top_alpha.shape[1]):
                ax = fg.add_subplot(
                    4, n_tops * len(subj),
                    n_tops * len(subj) * 3 + n_tops * j_s + j_t + 1 +
                    top_ica.shape[1])
                ax.set_xlabel('CSP{}'.format(j_t + 1))
                labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                channels = [
                    label for label in labels if label not in drop_channels
                ]
                pos = ch_names_to_2d_pos(channels)
                plot_topomap(data=top_alpha[:, j_t],
                             pos=pos,
                             axes=ax,
                             show=False)

            # plot powers
            if normalize_by == 'opened':
                norm = powers['1. Opened'].mean()
            elif normalize_by == 'beta':
                norm = np.mean(pow_theta)
            else:
                print('WARNING: norm = 1')
                # Actually apply the fallback announced by the warning;
                # without it `norm` would be unbound on the next line.
                norm = 1
            print('norm', norm)

            ax1 = fg.add_subplot(3, len(subj), j_s + 1)
            ax = fg.add_subplot(3, len(subj), j_s + len(subj) + 1)
            t = 0
            for j_p, ((_, power),
                      (name, x)) in enumerate(zip(powers.items(),
                                                  raw.items())):
                # Debug hook: only fires for a protocol stored under the
                # literal key '2228. FB'.
                if name == '2228. FB':
                    from scipy.signal import periodogram
                    fff = plt.figure()
                    fff.gca().plot(*periodogram(x, fs, nfft=fs * 3),
                                   c=cm[name.split()[1]])
                    plt.xlim(0, 80)
                    plt.ylim(0, 3e-11)
                    plt.show()
                print(name)
                # Protocols are laid out one after another on the time axis.
                time = np.arange(t, t + len(x)) / fs
                # Colour key: second word of the name with digits removed.
                color = cm[''.join(
                    [i for i in name.split()[1] if not i.isdigit()])]
                ax1.plot(time, fft_filter(x, fs, (2, 45)), c=color, alpha=0.4)
                ax1.plot(time, alpha[name], c=color)
                t += len(x)
                ax.plot([j_p], [power.mean() / norm],
                        'o',
                        c=color,
                        markersize=10)
                ax.errorbar([j_p], [power.mean() / norm],
                            yerr=power.std() / norm,
                            c=color,
                            ecolor=color)

            # Regression over the 'FB' protocols; np.hstack raises on an
            # empty list, so skip the fit when there is nothing to fit.
            fb_items = [(j, pows)
                        for j, (key, pows) in enumerate(powers.items())
                        if 'FB' in key]
            if fb_items:
                fb_x = np.hstack([[j] * len(pows) for j, pows in fb_items])
                fb_y = np.hstack([pows for _, pows in fb_items]) / norm
                sns.regplot(x=fb_x,
                            y=fb_y,
                            ax=ax,
                            color=cm['FB'],
                            scatter=False,
                            truncate=True)

            ax1.set_xlim(0, t / fs)
            ax1.set_ylim(-40, 40)
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=70)
            ax.set_xticks(range(len(powers)))
            ax.set_xticklabels(powers.keys())
            ax.set_ylim(0, 3)
            ax.set_xlim(-1, len(powers))
            ax1.set_title('Day {}'.format(j_s + 1))
    return fg
Exemplo n.º 6
0
            lengths_buffer = []
            for j, y in enumerate(mask_copy):
                if y:
                    lengths_buffer.append(x[j])
                elif len(lengths_buffer) > 0:
                    lengths.append(np.mean(lengths_buffer))
                    lengths_buffer = []
            print(lengths)
            return np.array(lengths)



        # Read everything needed from this experiment's HDF5 file; the path
        # is assembled with backslash separators.
        with h5py.File('{}\\{}\\{}'.format(dir_, experiment, 'experiment_data.h5')) as f:
            fs, channels, p_names = get_info(f, settings['drop_channels'])
            # Only the first value returned by load_rejections is kept.
            if reject:
                rejections = load_rejections(f, reject_alpha=True)[0]
            else:
                rejections = None
            # Spatial filter stored for the 'left' signal of protocol 15; it
            # is drawn as a topographic map on the current axes and saved.
            spatial = f['protocol15/signals_stats/left/spatial_filter'][:]
            plot_topomap(spatial, ch_names_to_2d_pos(channels), axes=plt.gca(), show=False)
            plt.savefig('alphaS{}_Day{}_spatial_filter'.format(subj, day+1))
            mu_band = f['protocol15/signals_stats/left/bandpass'][:]
            #mu_band = (12, 13)
            # Twice the reciprocal of the smallest band edge (two periods of
            # the lowest frequency, assuming mu_band is in Hz -- TODO confirm).
            max_gap = 1 / min(mu_band) * 2
            # NOTE(review): "sate" looks like a typo for "state"; the name is
            # left unchanged because later code may refer to it.
            min_sate_duration = max_gap * 2
            raw = OrderedDict()
            signal = OrderedDict()
            for j, name in enumerate(p_names):
                # Datasets are addressed with a 1-based protocol number.
                x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejections)
                raw = add_data(raw, name, x, j)
                signal = add_data(signal, name, f['protocol{}/signals_data'.format(j + 1)][:], j)
Exemplo n.º 7
0
def plot_results(pilot_dir, subj, channel, alpha_band=(9, 14), theta_band=(3, 6), drop_channels=None, dc=False,
                 reject_alpha=True, normalize_by='opened'):
    """Plot signals, normalized band powers and rejection topographies.

    One column of axes is added to a single figure for every experiment in
    ``subj``.  For each experiment the file ``experiment_data.h5`` located
    under ``pilot_dir`` / ``experiment`` (joined with backslashes) is read and
    the following is drawn:

    * first row: the (2, 45)-filtered data of every protocol, with the second
      value returned by ``get_protocol_power`` overlaid;
    * second row: mean and standard deviation of every protocol's power
      divided by the normalization constant, plus a regression line over the
      protocols whose name contains ``'FB'`` (skipped when there are none);
    * bottom row: topographic maps of the columns of ``top_ica`` and
      ``top_alpha`` as returned by ``load_rejections``.

    :param pilot_dir: directory that contains one sub-directory per experiment
    :param subj: sequence of experiment directory names; its position in the
        sequence is used as the "Day" number in the titles
    :param channel: name of the channel to analyse; looked up with
        ``channels.index`` (raises ``ValueError`` if it is missing)
    :param alpha_band: band forwarded to ``get_protocol_power`` for the
        plotted powers
    :param theta_band: band forwarded to ``get_protocol_power`` for the
        ``'FB'`` protocols; only used for the ``'beta'`` normalization
    :param drop_channels: channel labels to exclude; ``None`` means none
    :param dc: forwarded unchanged to ``get_protocol_power``
    :param reject_alpha: forwarded unchanged to ``load_rejections``
    :param normalize_by: ``'opened'`` divides by the mean power stored under
        the key ``'1. Opened'``; ``'beta'`` divides by the mean of the
        ``theta_band`` powers of the ``'FB'`` protocols (the option name does
        not match the band it uses); any other value disables normalization
        (constant 1) and prints a warning
    :return: the figure all axes were added to
    """
    drop_channels = drop_channels or []
    cm = get_colors()
    fg = plt.figure(figsize=(30, 6))
    for j_s, experiment in enumerate(subj):
        with h5py.File('{}\\{}\\{}'.format(pilot_dir, experiment, 'experiment_data.h5')) as f:
            rejections, top_alpha, top_ica = load_rejections(f, reject_alpha=reject_alpha)
            fs, channels, p_names = get_info(f, drop_channels)
            ch = channels.index(channel)

            # collect powers
            powers = OrderedDict()
            raw = OrderedDict()
            alpha = OrderedDict()
            pow_theta = []
            for j, name in enumerate(p_names):
                power, alpha_x, x = get_protocol_power(f, j, fs, rejections, ch, alpha_band, dc=dc)
                if 'FB' in name:
                    pow_theta.append(get_protocol_power(f, j, fs, rejections, ch, theta_band, dc=dc)[0].mean())
                powers = add_data(powers, name, power, j)
                raw = add_data(raw, name, x, j)
                alpha = add_data(alpha, name, alpha_x, j)

            # plot rejections (bottom row of a 4-row grid)
            # NOTE: both loops rebind `fs` and `channels` from the stream
            # info; the time axis below uses whichever `fs` was bound last.
            n_tops = top_ica.shape[1] + top_alpha.shape[1]
            for j_t in range(top_ica.shape[1]):
                ax = fg.add_subplot(4, n_tops * len(subj), n_tops * len(subj) * 3 + n_tops * j_s + j_t + 1)
                ax.set_xlabel('ICA{}'.format(j_t + 1))
                labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                channels = [label for label in labels if label not in drop_channels]
                pos = ch_names_to_2d_pos(channels)
                plot_topomap(data=top_ica[:, j_t], pos=pos, axes=ax, show=False)
            for j_t in range(top_alpha.shape[1]):
                ax = fg.add_subplot(4, n_tops * len(subj),
                                    n_tops * len(subj) * 3 + n_tops * j_s + j_t + 1 + top_ica.shape[1])
                ax.set_xlabel('CSP{}'.format(j_t + 1))
                labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                channels = [label for label in labels if label not in drop_channels]
                pos = ch_names_to_2d_pos(channels)
                plot_topomap(data=top_alpha[:, j_t], pos=pos, axes=ax, show=False)

            # plot powers
            if normalize_by == 'opened':
                norm = powers['1. Opened'].mean()
            elif normalize_by == 'beta':
                norm = np.mean(pow_theta)
            else:
                print('WARNING: norm = 1')
                # Actually apply the fallback announced by the warning;
                # without it `norm` would be unbound on the next line.
                norm = 1
            print('norm', norm)

            ax1 = fg.add_subplot(3, len(subj), j_s + 1)
            ax = fg.add_subplot(3, len(subj), j_s + len(subj) + 1)
            t = 0
            for j_p, ((_, power), (name, x)) in enumerate(zip(powers.items(), raw.items())):
                # Debug hook: only fires for a protocol stored under the
                # literal key '2228. FB'.
                if name == '2228. FB':
                    from scipy.signal import periodogram
                    fff = plt.figure()
                    fff.gca().plot(*periodogram(x, fs, nfft=fs * 3), c=cm[name.split()[1]])
                    plt.xlim(0, 80)
                    plt.ylim(0, 3e-11)
                    plt.show()
                print(name)
                # Protocols are laid out one after another on the time axis.
                time = np.arange(t, t + len(x)) / fs
                # Colour key: second word of the name with digits removed.
                color = cm[''.join([i for i in name.split()[1] if not i.isdigit()])]
                ax1.plot(time, fft_filter(x, fs, (2, 45)), c=color, alpha=0.4)
                ax1.plot(time, alpha[name], c=color)
                t += len(x)
                ax.plot([j_p], [power.mean() / norm], 'o', c=color, markersize=10)
                ax.errorbar([j_p], [power.mean() / norm], yerr=power.std() / norm, c=color, ecolor=color)

            # Regression over the 'FB' protocols; np.hstack raises on an
            # empty list, so skip the fit when there is nothing to fit.
            fb_items = [(j, pows) for j, (key, pows) in enumerate(powers.items()) if 'FB' in key]
            if fb_items:
                fb_x = np.hstack([[j] * len(pows) for j, pows in fb_items])
                fb_y = np.hstack([pows for _, pows in fb_items]) / norm
                sns.regplot(x=fb_x, y=fb_y, ax=ax, color=cm['FB'], scatter=False, truncate=True)

            ax1.set_xlim(0, t / fs)
            ax1.set_ylim(-40, 40)
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=70)
            ax.set_xticks(range(len(powers)))
            ax.set_xticklabels(powers.keys())
            ax.set_ylim(0, 3)
            ax.set_xlim(-1, len(powers))
            ax1.set_title('Day {}'.format(j_s + 1))
    return fg
Exemplo n.º 8
0
from pynfb.protocols.ssd.topomap_selector_ica import ICADialog
from PyQt5 import QtGui, QtWidgets

# The Qt application object is created before the first figure is opened.
a = QtWidgets.QApplication([])

fig1, axes1 = plt.subplots(ncols=3)

# Only subject index 3 is visited, for day indices 2, 3 and 4.
for subj in range(3, 4):
    for day in range(2, 5):
        experiments = settings['subjects'][subj]
        experiment = experiments[day]
        reject = True
        with h5py.File('{}\\{}\\{}'.format(
                settings['dir'], experiment, 'experiment_data.h5')) as f:
            fs, channels, p_names = get_info(f, settings['drop_channels'])
            rejection, alpha, ica = load_rejections(f, reject_alpha=False)
            raw_before, raw_after = OrderedDict(), OrderedDict()
            for j, name in enumerate(p_names):
                # Protocols that are neither among the first three nor among
                # the last two are left untouched.
                if not (j < 3 or j > len(p_names) - 3):
                    continue
                # Datasets are addressed with a 1-based protocol number.
                x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs,
                            rejection if reject else None)
                if j < 3:
                    raw_before = add_data_simple(raw_before, name, x)
                else:
                    print(x.shape)
                    raw_after = add_data_simple(raw_after, name, x)

        # Stack the selected recordings in a fixed order.
        xx = np.concatenate([
            raw_before['Opened'],
            raw_before['Baseline'],
            raw_after['Baseline'],
            raw_before['Left'],
            raw_before['Right'],
            raw_after['Left'],
            raw_after['Right'],
        ])
        #xx[:, channels.index('C3')] = xx[:, channels.index('C3')]