Example #1
1
    def update_figure(self, data, pos=None, names=None, show_names=None, show_colorbar=True, central_text=None,
                      right_bottom_text=None, show_not_found_symbol=False, montage=None):
        """Redraw the topographic map on ``self.axes`` and refresh the canvas.

        Parameters
        ----------
        data : array-like
            One value per channel; converted to a float ndarray.
        pos : array-like, optional
            Channel positions for ``plot_topomap``. When neither ``pos`` nor
            ``montage`` is given it is derived from ``names`` with
            ``ch_names_to_2d_pos``.
        names : sequence of str, optional
            Channel names. Replaced by ``montage.get_names('EEG')`` when a
            montage is given.
        show_names : sequence of str, optional
            Channel names (case-insensitive) whose markers are drawn via the
            ``mask``. Defaults to a fixed list (O1, O2, CZ, T3, T4, T7, T8,
            FP1, FP2).
        show_colorbar : bool
            Add a horizontal colorbar attached to ``self.axes``.
        central_text : str, optional
            Text drawn centred at data coordinates (0, 0).
        right_bottom_text : str, optional
            Text drawn at data coordinates (-0.65, 0.65), left/top aligned.
        show_not_found_symbol : bool
            Overlay a '/' and an 'O' at (0, 0).
        montage : object, optional
            If given, positions and names come from ``montage.get_pos('EEG')``
            and ``montage.get_names('EEG')``.
        """
        if montage is None:
            if pos is None:
                pos = ch_names_to_2d_pos(names)
        else:
            pos = montage.get_pos('EEG')
            names = montage.get_names('EEG')
        # Float dtype is required: the constant-data nudge below adds +/-0.1,
        # which an integer array would silently truncate away.
        data = np.array(data, dtype=float)
        self.axes.clear()
        if self.colorbar:
            self.colorbar.remove()
            # Forget the removed colorbar so a later call (e.g. after one with
            # show_colorbar=False) does not try to remove it a second time.
            self.colorbar = None
        if show_names is None:
            show_names = ['O1', 'O2', 'CZ', 'T3', 'T4', 'T7', 'T8', 'FP1', 'FP2']
        highlighted = {name.upper() for name in show_names}
        # `is not None and len(...)` instead of plain truthiness so that
        # ndarray-valued names do not raise an ambiguous-truth-value error.
        if names is not None and len(names) > 0:
            mask = np.array([name.upper() in highlighted for name in names])
        else:
            mask = None
        v_min, v_max = None, None
        # Presumably avoids a zero-width colour range for constant data: pull
        # two channels apart and pin the colour limits.
        if data.size > 1 and (data == data[0]).all():
            data[0] += 0.1
            data[1] -= 0.1
            v_min, v_max = -1, 1
        image, _ = plot_topomap(data, pos, axes=self.axes, show=False, contours=0, names=names, show_names=True,
                                mask=mask,
                                mask_params=dict(marker='o',
                                                 markerfacecolor='w',
                                                 markeredgecolor='w',
                                                 linewidth=0,
                                                 markersize=3),
                                vmin=v_min,
                                vmax=v_max)
        if central_text is not None:
            self.axes.text(0, 0, central_text, horizontalalignment='center', verticalalignment='center')

        if right_bottom_text is not None:
            self.axes.text(-0.65, 0.65, right_bottom_text, horizontalalignment='left', verticalalignment='top')

        if show_not_found_symbol:
            self.axes.text(0, 0, '/', horizontalalignment='center', verticalalignment='center')
            self.axes.text(0, 0, 'O', size=10, horizontalalignment='center', verticalalignment='center')

        if show_colorbar:
            self.colorbar = self.fig.colorbar(image, orientation='horizontal', ax=self.axes)
            self.colorbar.ax.tick_params(labelsize=6)
            # Rotate the live tick labels instead of copying get_xticklabels()
            # back in: before the first draw those labels can be empty, and
            # set_xticklabels would freeze them with a FixedFormatter.
            self.colorbar.ax.tick_params(axis='x', labelrotation=90)
        self.draw()
Example #2
0
    def update_figure(self, data, pos=None, names=None, show_names=None, show_colorbar=True, central_text=None,
                      right_bottom_text=None, show_not_found_symbol=False, montage=None):
        """Redraw the topographic map on ``self.axes`` and refresh the canvas.

        Parameters
        ----------
        data : array-like
            One value per channel; converted to a float ndarray.
        pos : array-like, optional
            Channel positions for ``plot_topomap``. When neither ``pos`` nor
            ``montage`` is given it is derived from ``names`` with
            ``ch_names_to_2d_pos``.
        names : sequence of str, optional
            Channel names. Replaced by ``montage.get_names('EEG')`` when a
            montage is given.
        show_names : sequence of str, optional
            Channel names (case-insensitive) whose markers are drawn via the
            ``mask``. Defaults to a fixed list (O1, O2, CZ, T3, T4, T7, T8,
            FP1, FP2).
        show_colorbar : bool
            Add a horizontal colorbar attached to ``self.axes``.
        central_text : str, optional
            Text drawn centred at data coordinates (0, 0).
        right_bottom_text : str, optional
            Text drawn at data coordinates (-0.65, 0.65), left/top aligned.
        show_not_found_symbol : bool
            Overlay a '/' and an 'O' at (0, 0).
        montage : object, optional
            If given, positions and names come from ``montage.get_pos('EEG')``
            and ``montage.get_names('EEG')``.
        """
        if montage is None:
            if pos is None:
                pos = ch_names_to_2d_pos(names)
        else:
            pos = montage.get_pos('EEG')
            names = montage.get_names('EEG')
        # Float dtype is required: the constant-data nudge below adds +/-0.1,
        # which an integer array would silently truncate away.
        data = np.array(data, dtype=float)
        self.axes.clear()
        if self.colorbar:
            self.colorbar.remove()
            # Forget the removed colorbar so a later call (e.g. after one with
            # show_colorbar=False) does not try to remove it a second time.
            self.colorbar = None
        if show_names is None:
            show_names = ['O1', 'O2', 'CZ', 'T3', 'T4', 'T7', 'T8', 'FP1', 'FP2']
        highlighted = {name.upper() for name in show_names}
        # `is not None and len(...)` instead of plain truthiness so that
        # ndarray-valued names do not raise an ambiguous-truth-value error.
        if names is not None and len(names) > 0:
            mask = np.array([name.upper() in highlighted for name in names])
        else:
            mask = None
        v_min, v_max = None, None
        # Presumably avoids a zero-width colour range for constant data: pull
        # two channels apart and pin the colour limits.
        if data.size > 1 and (data == data[0]).all():
            data[0] += 0.1
            data[1] -= 0.1
            v_min, v_max = -1, 1
        image, _ = plot_topomap(data, pos, axes=self.axes, show=False, contours=0, names=names, show_names=True,
                                mask=mask,
                                mask_params=dict(marker='o',
                                                 markerfacecolor='w',
                                                 markeredgecolor='w',
                                                 linewidth=0,
                                                 markersize=3),
                                vmin=v_min,
                                vmax=v_max)
        if central_text is not None:
            self.axes.text(0, 0, central_text, horizontalalignment='center', verticalalignment='center')

        if right_bottom_text is not None:
            self.axes.text(-0.65, 0.65, right_bottom_text, horizontalalignment='left', verticalalignment='top')

        if show_not_found_symbol:
            self.axes.text(0, 0, '/', horizontalalignment='center', verticalalignment='center')
            self.axes.text(0, 0, 'O', size=10, horizontalalignment='center', verticalalignment='center')

        if show_colorbar:
            self.colorbar = self.fig.colorbar(image, orientation='horizontal', ax=self.axes)
            self.colorbar.ax.tick_params(labelsize=6)
            # Rotate the live tick labels instead of copying get_xticklabels()
            # back in: before the first draw those labels can be empty, and
            # set_xticklabels would freeze them with a FixedFormatter.
            self.colorbar.ax.tick_params(axis='x', labelrotation=90)
        self.draw()
Example #3
0
                elif j > len(p_names) - 3:
                    x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs,
                                rejection if reject else None)
                    print(x.shape)
                    raw_after = add_data_simple(raw_after, name, x)

        xx = np.concatenate([
            raw_before['Opened'],
            raw_before['Baseline'],
            raw_after['Baseline'],
            raw_before['Left'],
            raw_before['Right'],
            raw_after['Left'],
            raw_after['Right'],
        ])
        #xx[:, channels.index('C3')] = xx[:, channels.index('C3')]
        rejection, spatial, topography, unmixing_matrix, bandpass, _ = ICADialog.get_rejection(
            xx, channels, fs, mode='csp', states=None)
        from mne.viz import plot_topomap
        print(spatial)
        fig, axes = plt.subplots(ncols=rejection.topographies.shape[1])
        if not isinstance(axes, type(axes1)):
            axes = [axes]
        for ax, top in zip(axes, rejection.topographies.T):
            plot_topomap(top,
                         ch_names_to_2d_pos(channels),
                         axes=ax,
                         show=False)
        fig.savefig('csp_S{}_D{}.png'.format(subj, day + 1))
        fig.show()
Example #4
0
            fs, channels, p_names = get_info(f, settings['drop_channels'])
            rejection, alpha, ica = load_rejections(f, reject_alpha=False)
            raw_before = OrderedDict()
            raw_after = OrderedDict()
            for j, name in enumerate(p_names):
                if j < 3:
                    x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejection if reject else None)
                    raw_before = add_data_simple(raw_before, name, x)
                elif j > len(p_names) - 3:
                    x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejection if reject else None)
                    print(x.shape)
                    raw_after = add_data_simple(raw_after, name, x)



        xx = np.concatenate([raw_before['Opened'], raw_before['Baseline'], raw_after['Baseline'], raw_before['Left'], raw_before['Right'],  raw_after['Left'], raw_after['Right'], ])
        #xx[:, channels.index('C3')] = xx[:, channels.index('C3')]
        rejection, spatial, topography, unmixing_matrix, bandpass, _ = ICADialog.get_rejection(xx, channels, fs, mode='csp', states=None)
        from mne.viz import plot_topomap
        print(spatial)
        fig, axes = plt.subplots(ncols=rejection.topographies.shape[1])
        if not isinstance(axes, type(axes1)) :
            axes = [axes]
        for ax, top in zip(axes, rejection.topographies.T):
            plot_topomap(top, ch_names_to_2d_pos(channels), axes=ax, show=False)
        fig.savefig('csp_S{}_D{}.png'.format(subj, day+1))
        fig.show()



Example #5
0
    def __init__(self,
                 current_protocol,
                 protocols,
                 signals,
                 n_signals=1,
                 parent=None,
                 n_channels=32,
                 max_protocol_n_samples=None,
                 experiment=None,
                 freq=500,
                 plot_raw_flag=True,
                 plot_signals_flag=True,
                 plot_source_space_flag=False,
                 show_subject_window=True,
                 channels_labels=None,
                 subject_backend_expyriment=False):
        """Build and show the experimenter's main window.

        The window is optionally accompanied by a subject window and a
        source-space window.

        Args:
            current_protocol: Protocol handed to the subject window
                (``SubjectWindow`` or ``ExpyrimentSubjectWindow``).
            protocols: Sequence of protocol objects. Their ``name`` and
                ``duration`` attributes feed the ``PlayerLineInfo`` status line.
            signals: Derived-signal objects. Their ``name`` attributes label the
                signals viewer, and each ``reset_statistic_acc`` is connected to
                the start button. The sequence is also passed to
                ``SourceSpaceRecontructor`` when the source-space window is
                enabled.
            n_signals: Not used by this constructor.
            parent: Parent widget, forwarded to the base-class constructor.
            n_channels: Stored as ``self.n_channels``.
            max_protocol_n_samples: Not used by this constructor.
            experiment: Stored as ``self.experiment``.
            freq: Stored as ``self.source_freq`` and passed to both signal
                viewers (presumably the sampling rate in Hz -- confirm).
            plot_raw_flag: Initial state of the 'plot raw' checkbox.
            plot_signals_flag: Initial state of the 'plot signals' checkbox.
            plot_source_space_flag: If true, create and show a
                ``SourceSpaceWindow``; ``self.source_space_window`` is only set
                in that case.
            show_subject_window: If true, create the subject window; otherwise
                ``self.subject_window`` is None.
            channels_labels: Channel names for the raw viewer; also converted
                to 2-D positions for the (currently disabled) topomap widget.
            subject_backend_expyriment: If true, use ``ExpyrimentSubjectWindow``
                instead of ``SubjectWindow`` for the subject window.
        """
        super(MainWindow, self).__init__(parent)

        # Which windows to draw:
        self.plot_source_space_flag = plot_source_space_flag
        self.show_subject_window = show_subject_window

        # status info
        # NOTE(review): durations are wrapped in an extra list ([[...]]) while
        # names are not -- confirm this is the shape PlayerLineInfo expects.
        self.status = PlayerLineInfo([p.name for p in protocols],
                                     [[p.duration for p in protocols]])

        self.source_freq = freq
        self.experiment = experiment
        self.signals = signals

        # player panel
        # Connection order matters: on "start", every signal's statistics
        # accumulator is reset before update_first_status runs.
        self.player_panel = PlayerButtonsWidget(parent=self)
        self.player_panel.restart.clicked.connect(self.restart_experiment)
        for signal in signals:
            self.player_panel.start.clicked.connect(signal.reset_statistic_acc)
        self.player_panel.start.clicked.connect(self.update_first_status)
        self._first_time_start_press = True

        # timer label
        self.timer_label = QtGui.QLabel('tf')

        # signals viewer
        self.signals_viewer = DerivedSignalViewer(
            freq, [signal.name for signal in signals])

        # raw data viewer
        self.raw_viewer = RawSignalViewer(freq, channels_labels)
        self.n_channels = n_channels
        # presumably a sample-count window/buffer size -- TODO confirm usage
        self.n_samples = 2000

        self.plot_raw_checkbox = QtGui.QCheckBox('plot raw')
        self.plot_raw_checkbox.setChecked(plot_raw_flag)
        self.plot_signals_checkbox = QtGui.QCheckBox('plot signals')
        self.plot_signals_checkbox.setChecked(plot_signals_flag)
        self.autoscale_raw_chekbox = QtGui.QCheckBox('autoscale')
        self.autoscale_raw_chekbox.setChecked(True)

        # topomaper
        # `pos` is unused while the topomap widget below stays commented out.
        pos = ch_names_to_2d_pos(channels_labels)
        #self.topomaper = TopomapWidget(pos)

        # dc_blocker
        self.dc_blocker = DCBlocker()

        # main window layout
        # addWidget(widget, row, col, rowspan, colspan) on a pyqtgraph grid:
        # row 0 signals, row 1 checkboxes, row 2 raw data, row 3 player/timer,
        # row 4 status line. Rows 0 and 2 get extra stretch.
        layout = pg.LayoutWidget(self)
        layout.addWidget(self.signals_viewer, 0, 0, 1, 3)
        layout.addWidget(self.plot_raw_checkbox, 1, 0, 1, 1)
        layout.addWidget(self.plot_signals_checkbox, 1, 2, 1, 1)
        layout.addWidget(self.autoscale_raw_chekbox, 1, 1, 1, 1)
        layout.addWidget(self.raw_viewer, 2, 0, 1, 3)
        layout.addWidget(self.player_panel, 3, 0, 1, 1)
        layout.addWidget(self.timer_label, 3, 1, 1, 1)
        #layout.addWidget(self.topomaper, 3, 2, 1, 1)
        layout.addWidget(self.status, 4, 0, 1, 3)
        layout.layout.setRowStretch(0, 2)
        layout.layout.setRowStretch(2, 2)
        self.setCentralWidget(layout)

        # main window settings
        # The main window is shown before any secondary window is created.
        self.resize(800, 600)
        self.show()

        # subject window
        # Only the SubjectWindow branch calls .show() here; the expyriment
        # backend is presumably displayed by its own class -- confirm.
        if show_subject_window:
            if not subject_backend_expyriment:
                self.subject_window = SubjectWindow(self, current_protocol)
                self.subject_window.show()
            else:
                self.subject_window = ExpyrimentSubjectWindow(
                    self, current_protocol)
            self._subject_window_want_to_close = False
        else:
            self.subject_window = None
            self._subject_window_want_to_close = None

        # Source space window
        if plot_source_space_flag:
            source_space_protocol = SourceSpaceRecontructor(signals)
            self.source_space_window = SourceSpaceWindow(
                self, source_space_protocol)
            self.source_space_window.show()

        # time counter
        # t0/t start at the wall-clock time of window construction.
        self.time_counter = 0
        self.time_counter1 = 0
        self.t0 = time.time()
        self.t = self.t0
Example #6
0
                        powers['{}. Closed'.format(j+1)] = pow[:len(pow)//2]
                        powers['{}. Opened'.format(j+1)] = pow[len(pow)//2:]
                    elif name == 'Rotate':
                        powers['{}. Right'.format(j+1)] = pow[:len(pow)//2]
                        powers['{}. Left'.format(j+1)] = pow[len(pow)//2:]
                    else:
                        powers['{}. {}'.format(j+1, name)] = pow


                # plot rejections
                for j_t in range(top_ica.shape[1]):
                    ax = fg.add_subplot(5, top_ica.shape[1]*len(subj), top_ica.shape[1]*len(subj)*3 + top_ica.shape[1]*j_s + j_t + 1)
                    ax.set_xlabel('ICA{}'.format(j_t+1))
                    labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                    channels = [label for label in labels if label not in drop_channels]
                    pos = ch_names_to_2d_pos(channels)
                    plot_topomap(data=top_ica[:, j_t], pos=pos, axes=ax, show=False)
                for j_t in range(top_alpha.shape[1]):
                    ax = fg.add_subplot(5, top_alpha.shape[1]*len(subj), top_alpha.shape[1]*len(subj)*4 + top_alpha.shape[1]*j_s + j_t + 1)
                    ax.set_xlabel('CSP{}'.format(j_t+1))
                    labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                    channels = [label for label in labels if label not in drop_channels]
                    pos = ch_names_to_2d_pos(channels)
                    plot_topomap(data=top_alpha[:, j_t], pos=pos, axes=ax, show=False)


                # plot powers
                norm = powers['{}. Baseline'.format(p_names.index('Baseline') + 1)].mean()
                #norm = np.mean(pow_theta)
                print('norm', norm)
                ax = fg.add_subplot(2, len(subj), j_s + 1)
Example #7
0
                elif len(lengths_buffer) > 0:
                    lengths.append(np.mean(lengths_buffer))
                    lengths_buffer = []
            print(lengths)
            return np.array(lengths)



        with h5py.File('{}\\{}\\{}'.format(dir_, experiment, 'experiment_data.h5')) as f:
            fs, channels, p_names = get_info(f, settings['drop_channels'])
            if reject:
                rejections = load_rejections(f, reject_alpha=True)[0]
            else:
                rejections = None
            spatial = f['protocol15/signals_stats/left/spatial_filter'][:]
            plot_topomap(spatial, ch_names_to_2d_pos(channels), axes=plt.gca(), show=False)
            plt.savefig('alphaS{}_Day{}_spatial_filter'.format(subj, day+1))
            mu_band = f['protocol15/signals_stats/left/bandpass'][:]
            #mu_band = (12, 13)
            max_gap = 1 / min(mu_band) * 2
            min_sate_duration = max_gap * 2
            raw = OrderedDict()
            signal = OrderedDict()
            for j, name in enumerate(p_names):
                x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejections)
                raw = add_data(raw, name, x, j)
                signal = add_data(signal, name, f['protocol{}/signals_data'.format(j + 1)][:], j)

        del raw[list(raw.keys())[-1]]
        # make csp:
        if run_ica:
def plot_results(pilot_dir,
                 subj,
                 channel,
                 alpha_band=(9, 14),
                 theta_band=(3, 6),
                 drop_channels=None,
                 dc=False,
                 reject_alpha=True,
                 normalize_by='opened'):
    """Plot signal traces, protocol power summaries and rejection topographies.

    For each experiment in ``subj`` the file
    ``<pilot_dir>\\<experiment>\\experiment_data.h5`` is opened and one figure
    column is filled with:

    * row 1 (of a 3-row grid): the third output of ``get_protocol_power``
      (presumably the channel signal) band-passed to 2-45 Hz, plus its
      ``alpha_band`` trace, coloured by protocol type;
    * row 2 (of a 3-row grid): mean +/- std of each protocol's ``alpha_band``
      power divided by ``norm``, with a regression line over 'FB' protocols;
    * row 4 (of a 4-row grid): ICA and CSP rejection topographies.

    Parameters
    ----------
    pilot_dir : str
        Directory containing one sub-directory per experiment.
    subj : sequence of str
        Experiment directory names; one figure column per entry.
    channel : str
        Name of the analysed channel.
    alpha_band, theta_band : tuple
        Frequency bands passed to ``get_protocol_power``. ``theta_band`` power
        is only computed for 'FB' protocols and only used when
        ``normalize_by == 'beta'``.
    drop_channels : list of str, optional
        Channel labels to exclude.
    dc : bool
        Forwarded to ``get_protocol_power``.
    reject_alpha : bool
        Forwarded to ``load_rejections``.
    normalize_by : str
        'opened': divide by the mean power of protocol '1. Opened';
        'beta': divide by the mean of the per-FB-protocol ``theta_band``
        powers; anything else: print a warning and use 1.

    Returns
    -------
    matplotlib.figure.Figure
    """
    drop_channels = drop_channels or []
    cm = get_colors()
    fg = plt.figure(figsize=(30, 6))
    for j_s, experiment in enumerate(subj):
        with h5py.File('{}\\{}\\{}'.format(pilot_dir, experiment,
                                           'experiment_data.h5')) as f:
            rejections, top_alpha, top_ica = load_rejections(
                f, reject_alpha=reject_alpha)
            fs, channels, p_names = get_info(f, drop_channels)
            ch = channels.index(channel)

            # collect powers
            powers = OrderedDict()
            raw = OrderedDict()
            alpha = OrderedDict()
            pow_theta = []
            for j, name in enumerate(p_names):
                power, alpha_x, x = get_protocol_power(f, j, fs, rejections,
                                                       ch, alpha_band, dc=dc)
                if 'FB' in name:
                    pow_theta.append(
                        get_protocol_power(f, j, fs, rejections, ch,
                                           theta_band, dc=dc)[0].mean())
                powers = add_data(powers, name, power, j)
                raw = add_data(raw, name, x, j)
                alpha = add_data(alpha, name, alpha_x, j)

            # plot rejections: ICA topographies, then CSP ones, in the 4th row
            n_ica = top_ica.shape[1]
            n_tops = n_ica + top_alpha.shape[1]
            if n_tops > 0:
                # The channel layout is loop-invariant, so read the stream
                # XML once. As before, this replaces `fs` with the XML rate.
                labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                pos = ch_names_to_2d_pos(
                    [label for label in labels if label not in drop_channels])
                components = [('ICA{}'.format(k + 1), top_ica[:, k])
                              for k in range(n_ica)]
                components += [('CSP{}'.format(k + 1), top_alpha[:, k])
                               for k in range(top_alpha.shape[1])]
                first_cell = n_tops * len(subj) * 3 + n_tops * j_s + 1
                for offset, (title, topography) in enumerate(components):
                    ax = fg.add_subplot(4, n_tops * len(subj),
                                        first_cell + offset)
                    ax.set_xlabel(title)
                    plot_topomap(data=topography, pos=pos, axes=ax,
                                 show=False)

            # plot powers
            if normalize_by == 'opened':
                norm = powers['1. Opened'].mean()
            elif normalize_by == 'beta':
                # NOTE: despite the option name, pow_theta holds theta_band power.
                norm = np.mean(pow_theta)
            else:
                print('WARNING: norm = 1')
                # Previously never assigned: UnboundLocalError on the first
                # subject, or a stale norm from the previous one.
                norm = 1
            print('norm', norm)

            ax1 = fg.add_subplot(3, len(subj), j_s + 1)
            ax = fg.add_subplot(3, len(subj), j_s + len(subj) + 1)
            t = 0
            for j_p, ((_, power), (name, x)) in enumerate(
                    zip(powers.items(), raw.items())):
                if name == '2228. FB':
                    # leftover debugging hook: periodogram of one protocol
                    from scipy.signal import periodogram
                    fff = plt.figure()
                    fff.gca().plot(*periodogram(x, fs, nfft=fs * 3),
                                   c=cm[name.split()[1]])
                    plt.xlim(0, 80)
                    plt.ylim(0, 3e-11)
                    plt.show()
                print(name)
                t_axis = np.arange(t, t + len(x)) / fs
                # colour key = protocol type with any digits stripped
                color = cm[''.join(
                    c for c in name.split()[1] if not c.isdigit())]
                ax1.plot(t_axis, fft_filter(x, fs, (2, 45)), c=color,
                         alpha=0.4)
                ax1.plot(t_axis, alpha[name], c=color)
                t += len(x)
                mean_power = power.mean() / norm
                ax.plot([j_p], [mean_power], 'o', c=color, markersize=10)
                ax.errorbar([j_p], [mean_power], yerr=power.std() / norm,
                            c=color, ecolor=color)

            # regression over all FB samples (x = protocol index); skipped when
            # there are no FB protocols, where np.hstack([]) would raise
            fb_items = [(j, key) for j, key in enumerate(powers) if 'FB' in key]
            if fb_items:
                fb_x = np.hstack([[j] * len(powers[key]) for j, key in fb_items])
                fb_y = np.hstack([powers[key] for _, key in fb_items]) / norm
                sns.regplot(x=fb_x,
                            y=fb_y,
                            ax=ax,
                            color=cm['FB'],
                            scatter=False,
                            truncate=True)

            ax1.set_xlim(0, t / fs)
            ax1.set_ylim(-40, 40)
            ax.set_xticks(range(len(powers)))
            ax.set_xticklabels(list(powers.keys()), rotation=70)
            ax.set_ylim(0, 3)
            ax.set_xlim(-1, len(powers))
            ax1.set_title('Day {}'.format(j_s + 1))
    return fg
Example #9
0
                elif len(lengths_buffer) > 0:
                    lengths.append(np.mean(lengths_buffer))
                    lengths_buffer = []
            print(lengths)
            return np.array(lengths)



        with h5py.File('{}\\{}\\{}'.format(dir_, experiment, 'experiment_data.h5')) as f:
            fs, channels, p_names = get_info(f, settings['drop_channels'])
            if reject:
                rejections = load_rejections(f, reject_alpha=True)[0]
            else:
                rejections = None
            spatial = f['protocol15/signals_stats/left/spatial_filter'][:]
            plot_topomap(spatial, ch_names_to_2d_pos(channels), axes=plt.gca(), show=False)
            plt.savefig('alphaS{}_Day{}_spatial_filter'.format(subj, day+1))
            mu_band = f['protocol15/signals_stats/left/bandpass'][:]
            #mu_band = (12, 13)
            max_gap = 1 / min(mu_band) * 2
            min_sate_duration = max_gap * 2
            raw = OrderedDict()
            signal = OrderedDict()
            for j, name in enumerate(p_names):
                x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejections)
                raw = add_data(raw, name, x, j)
                signal = add_data(signal, name, f['protocol{}/signals_data'.format(j + 1)][:], j)

        del raw[list(raw.keys())[-1]]
        # make csp:
        if run_ica:
Example #10
0
            labels_, fs_ = get_lsl_info_from_xml(f['stream_info.xml'][0])
            print(labels_)
            channels = [label for label in labels_ if label not in ['A1', 'A2', 'AUX']]
            print(labels_)

            pz_index = channels.index('Pz')
            raw = raw - np.dot(raw[:, [pz_index]], np.ones((1, raw.shape[1])))

            del channels[pz_index]
            raw = raw[:, np.arange(raw.shape[1]) != pz_index]


            signal = DerivedSignal(ind=0, name='Signal', bandpass_low=9, bandpass_high=14,
                                     spatial_filter=np.array([0]), n_channels=raw.shape[1])
            w = SignalsSSDManager([signal], raw, ch_names_to_2d_pos(channels), channels, None, None, [], sampling_freq=fs_ )
            w.exec_()

            rejections = signal.rejections.get_list()
            new_rejections[experiment] = rejections
    with open(new_rejections_file, 'wb') as pkl:
        pickle.dump(new_rejections, pkl)
    del a
else:
    print('file exist')
    with open(new_rejections_file, 'rb') as handle:
        new_rejections = pickle.load(handle)

print(new_rejections)

Example #11
0
def plot_results(pilot_dir, subj, channel, alpha_band=(9, 14), theta_band=(3, 6), drop_channels=None, dc=False,
                 reject_alpha=True, normalize_by='opened'):
    """Plot signal traces, protocol power summaries and rejection topographies.

    For each experiment in ``subj`` the file
    ``<pilot_dir>\\<experiment>\\experiment_data.h5`` is opened and one figure
    column is filled with:

    * row 1 (of a 3-row grid): the third output of ``get_protocol_power``
      (presumably the channel signal) band-passed to 2-45 Hz, plus its
      ``alpha_band`` trace, coloured by protocol type;
    * row 2 (of a 3-row grid): mean +/- std of each protocol's ``alpha_band``
      power divided by ``norm``, with a regression line over 'FB' protocols;
    * row 4 (of a 4-row grid): ICA and CSP rejection topographies.

    Parameters
    ----------
    pilot_dir : str
        Directory containing one sub-directory per experiment.
    subj : sequence of str
        Experiment directory names; one figure column per entry.
    channel : str
        Name of the analysed channel.
    alpha_band, theta_band : tuple
        Frequency bands passed to ``get_protocol_power``. ``theta_band`` power
        is only computed for 'FB' protocols and only used when
        ``normalize_by == 'beta'``.
    drop_channels : list of str, optional
        Channel labels to exclude.
    dc : bool
        Forwarded to ``get_protocol_power``.
    reject_alpha : bool
        Forwarded to ``load_rejections``.
    normalize_by : str
        'opened': divide by the mean power of protocol '1. Opened';
        'beta': divide by the mean of the per-FB-protocol ``theta_band``
        powers; anything else: print a warning and use 1.

    Returns
    -------
    matplotlib.figure.Figure
    """
    drop_channels = drop_channels or []
    cm = get_colors()
    fg = plt.figure(figsize=(30, 6))
    for j_s, experiment in enumerate(subj):
        with h5py.File('{}\\{}\\{}'.format(pilot_dir, experiment, 'experiment_data.h5')) as f:
            rejections, top_alpha, top_ica = load_rejections(f, reject_alpha=reject_alpha)
            fs, channels, p_names = get_info(f, drop_channels)
            ch = channels.index(channel)

            # collect powers
            powers = OrderedDict()
            raw = OrderedDict()
            alpha = OrderedDict()
            pow_theta = []
            for j, name in enumerate(p_names):
                power, alpha_x, x = get_protocol_power(f, j, fs, rejections, ch, alpha_band, dc=dc)
                if 'FB' in name:
                    pow_theta.append(get_protocol_power(f, j, fs, rejections, ch, theta_band, dc=dc)[0].mean())
                powers = add_data(powers, name, power, j)
                raw = add_data(raw, name, x, j)
                alpha = add_data(alpha, name, alpha_x, j)

            # plot rejections: ICA topographies, then CSP ones, in the 4th row
            n_ica = top_ica.shape[1]
            n_tops = n_ica + top_alpha.shape[1]
            if n_tops > 0:
                # The channel layout is loop-invariant, so read the stream
                # XML once. As before, this replaces `fs` with the XML rate.
                labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                pos = ch_names_to_2d_pos([label for label in labels if label not in drop_channels])
                components = [('ICA{}'.format(k + 1), top_ica[:, k]) for k in range(n_ica)]
                components += [('CSP{}'.format(k + 1), top_alpha[:, k]) for k in range(top_alpha.shape[1])]
                first_cell = n_tops * len(subj) * 3 + n_tops * j_s + 1
                for offset, (title, topography) in enumerate(components):
                    ax = fg.add_subplot(4, n_tops * len(subj), first_cell + offset)
                    ax.set_xlabel(title)
                    plot_topomap(data=topography, pos=pos, axes=ax, show=False)

            # plot powers
            if normalize_by == 'opened':
                norm = powers['1. Opened'].mean()
            elif normalize_by == 'beta':
                # NOTE: despite the option name, pow_theta holds theta_band power.
                norm = np.mean(pow_theta)
            else:
                print('WARNING: norm = 1')
                # Previously never assigned: UnboundLocalError on the first
                # subject, or a stale norm from the previous one.
                norm = 1
            print('norm', norm)

            ax1 = fg.add_subplot(3, len(subj), j_s + 1)
            ax = fg.add_subplot(3, len(subj), j_s + len(subj) + 1)
            t = 0
            for j_p, ((_, power), (name, x)) in enumerate(zip(powers.items(), raw.items())):
                if name == '2228. FB':
                    # leftover debugging hook: periodogram of one protocol
                    from scipy.signal import periodogram
                    fff = plt.figure()
                    fff.gca().plot(*periodogram(x, fs, nfft=fs * 3), c=cm[name.split()[1]])
                    plt.xlim(0, 80)
                    plt.ylim(0, 3e-11)
                    plt.show()
                print(name)
                t_axis = np.arange(t, t + len(x)) / fs
                # colour key = protocol type with any digits stripped
                color = cm[''.join(c for c in name.split()[1] if not c.isdigit())]
                ax1.plot(t_axis, fft_filter(x, fs, (2, 45)), c=color, alpha=0.4)
                ax1.plot(t_axis, alpha[name], c=color)
                t += len(x)
                mean_power = power.mean() / norm
                ax.plot([j_p], [mean_power], 'o', c=color, markersize=10)
                ax.errorbar([j_p], [mean_power], yerr=power.std() / norm, c=color, ecolor=color)

            # regression over all FB samples (x = protocol index); skipped when
            # there are no FB protocols, where np.hstack([]) would raise
            fb_items = [(j, key) for j, key in enumerate(powers) if 'FB' in key]
            if fb_items:
                fb_x = np.hstack([[j] * len(powers[key]) for j, key in fb_items])
                fb_y = np.hstack([powers[key] for _, key in fb_items]) / norm
                sns.regplot(x=fb_x, y=fb_y, ax=ax, color=cm['FB'], scatter=False, truncate=True)

            ax1.set_xlim(0, t / fs)
            ax1.set_ylim(-40, 40)
            ax.set_xticks(range(len(powers)))
            ax.set_xticklabels(list(powers.keys()), rotation=70)
            ax.set_ylim(0, 3)
            ax.set_xlim(-1, len(powers))
            ax1.set_title('Day {}'.format(j_s + 1))
    return fg
Example #12
0
            fs, channels, p_names = get_info(f, settings['drop_channels'])
            rejection, alpha, ica = load_rejections(f, reject_alpha=False)
            raw_before = OrderedDict()
            raw_after = OrderedDict()
            for j, name in enumerate(p_names):
                if j < 3:
                    x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejection if reject else None)
                    raw_before = add_data_simple(raw_before, name, x)
                elif j > len(p_names) - 3:
                    x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejection if reject else None)
                    print(x.shape)
                    raw_after = add_data_simple(raw_after, name, x)



        xx = np.concatenate([raw_before['Opened'], raw_before['Baseline'], raw_after['Baseline'], raw_before['Left'], raw_before['Right'],  raw_after['Left'], raw_after['Right'], ])
        #xx[:, channels.index('C3')] = xx[:, channels.index('C3')]
        rejection, spatial, topography, unmixing_matrix, bandpass, _ = ICADialog.get_rejection(xx, channels, fs, mode='csp', states=None)
        from mne.viz import plot_topomap
        print(spatial)
        fig, axes = plt.subplots(ncols=rejection.topographies.shape[1])
        if not isinstance(axes, type(axes1)) :
            axes = [axes]
        for ax, top in zip(axes, rejection.topographies.T):
            plot_topomap(top, ch_names_to_2d_pos(channels), axes=ax, show=False)
        fig.savefig('csp_S{}_D{}.png'.format(subj, day+1))
        fig.show()



Example #13
0
    with h5py.File('{}\\{}\\{}'.format(settings['dir'], experiment, 'experiment_data.h5')) as f:
        fs, channels, p_names = get_info(f, settings['drop_channels'])
        rejection, alpha, ica = None, None, None#load_rejections(f, reject_alpha=True)
        odict = OrderedDict()
        for j, name in enumerate(p_names):
            x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejection if reject else None)
            odict = add_data_simple(odict, name, x)
        raw[names[j_experiment]] = odict


    for j, key in enumerate(state_plot):

        f, Pxx = welch(raw[names[j_experiment]][key], fs, nperseg=2048, axis=0)
        #axes[j].semilogy(f, Pxx, alpha=1, c=cm[j_experiment*2+3])
        ax = axes[j, j_experiment]
        a, b = plot_topomap(np.log10(Pxx[np.argmin(np.abs(f-peak)), :]), ch_names_to_2d_pos(channels), cmap='Reds',
                            axes=ax, show=False, vmax=-10.5, vmin=-13)
        if j_experiment == 0:
            ax.set_ylabel(key)
        if j == len(state_plot)-1:
            ax.set_xlabel(names[j_experiment])

        #axes[j].set_xlim(0, 250)
        #axes[j].set_ylim(1e-19, 5e-10)
            #x_plot = np.abs(hilbert(fft_filter(raw_before[key][:, channels.index(ch)], fs)))
            #leg.append('P={:.3f}, D={:.3f}s'.format(Pxx[(f > 9) & (f < 14)].mean(), sum((x_plot > 5)) / fs/2))
        #axes[j].legend(leg)


fig2.colorbar(a)
plt.show()