def update_figure(self, data, pos=None, names=None, show_names=None, show_colorbar=True,
                  central_text=None, right_bottom_text=None, show_not_found_symbol=False,
                  montage=None):
    """Redraw the topographic map on ``self.axes`` and refresh the canvas.

    Parameters
    ----------
    data : array-like
        One value per channel; copied into a float array before plotting.
    pos : optional
        Channel positions handed to ``plot_topomap``. Ignored when `montage`
        is given; derived from `names` via ``ch_names_to_2d_pos`` when both
        `pos` and `montage` are None.
    names : sequence of str, optional
        Channel names. Overridden by ``montage.get_names('EEG')`` when a
        montage is given.
    show_names : sequence of str, optional
        Channel names to mark on the map (compared case-insensitively).
        Defaults to a fixed list of occipital/temporal/frontal electrodes.
    show_colorbar : bool
        Attach a horizontal colorbar below the map.
    central_text : str, optional
        Text drawn at the centre of the axes.
    right_bottom_text : str, optional
        Text drawn at axes coordinates (-0.65, 0.65).
    show_not_found_symbol : bool
        Overlay a '/' and an 'O' glyph at the centre of the axes.
    montage : optional
        Object exposing ``get_pos('EEG')`` and ``get_names('EEG')``; takes
        precedence over `pos` and `names`.
    """
    if montage is not None:
        pos = montage.get_pos('EEG')
        names = montage.get_names('EEG')
    elif pos is None:
        pos = ch_names_to_2d_pos(names)

    # Float copy: the constant-data workaround below modifies values in place,
    # which fails on integer/bool arrays and must never touch the caller's data.
    data = np.array(data, dtype=float)

    self.axes.clear()
    if self.colorbar:
        self.colorbar.remove()
        # Forget the removed colorbar so that a later call with
        # show_colorbar=False does not try to remove it a second time.
        self.colorbar = None

    if show_names is None:
        show_names = ['O1', 'O2', 'CZ', 'T3', 'T4', 'T7', 'T8', 'FP1', 'FP2']
    highlighted = {name.upper() for name in show_names}
    mask = np.array([name.upper() in highlighted for name in names]) if names else None

    v_min, v_max = None, None
    if len(data) and (data == data[0]).all():
        # All values are equal: nudge two samples apart and pin the colour
        # limits to [-1, 1].
        if len(data) > 1:
            data[0] += 0.1
            data[1] -= 0.1
        v_min, v_max = -1, 1

    image, _ = plot_topomap(
        data, pos, axes=self.axes, show=False, contours=0, names=names, show_names=True,
        mask=mask,
        mask_params=dict(marker='o', markerfacecolor='w', markeredgecolor='w',
                         linewidth=0, markersize=3),
        vmin=v_min, vmax=v_max)

    if central_text is not None:
        self.axes.text(0, 0, central_text,
                       horizontalalignment='center', verticalalignment='center')
    if right_bottom_text is not None:
        self.axes.text(-0.65, 0.65, right_bottom_text,
                       horizontalalignment='left', verticalalignment='top')
    if show_not_found_symbol:
        self.axes.text(0, 0, '/',
                       horizontalalignment='center', verticalalignment='center')
        self.axes.text(0, 0, 'O', size=10,
                       horizontalalignment='center', verticalalignment='center')

    if show_colorbar:
        self.colorbar = self.fig.colorbar(image, orientation='horizontal', ax=self.axes)
        self.colorbar.ax.tick_params(labelsize=6)
        self.colorbar.ax.set_xticklabels(self.colorbar.ax.get_xticklabels(), rotation=90)
    self.draw()
def update_figure(self, data, pos=None, names=None, show_names=None, show_colorbar=True,
                  central_text=None, right_bottom_text=None, show_not_found_symbol=False,
                  montage=None):
    """Redraw the topographic map on ``self.axes`` and refresh the canvas.

    Parameters
    ----------
    data : array-like
        One value per channel; copied into a float array before plotting.
    pos : optional
        Channel positions handed to ``plot_topomap``. Ignored when `montage`
        is given; derived from `names` via ``ch_names_to_2d_pos`` when both
        `pos` and `montage` are None.
    names : sequence of str, optional
        Channel names. Overridden by ``montage.get_names('EEG')`` when a
        montage is given.
    show_names : sequence of str, optional
        Channel names to mark on the map (compared case-insensitively).
        Defaults to a fixed list of occipital/temporal/frontal electrodes.
    show_colorbar : bool
        Attach a horizontal colorbar below the map.
    central_text : str, optional
        Text drawn at the centre of the axes.
    right_bottom_text : str, optional
        Text drawn at axes coordinates (-0.65, 0.65).
    show_not_found_symbol : bool
        Overlay a '/' and an 'O' glyph at the centre of the axes.
    montage : optional
        Object exposing ``get_pos('EEG')`` and ``get_names('EEG')``; takes
        precedence over `pos` and `names`.
    """
    if montage is not None:
        pos = montage.get_pos('EEG')
        names = montage.get_names('EEG')
    elif pos is None:
        pos = ch_names_to_2d_pos(names)

    # Float copy: the constant-data workaround below modifies values in place,
    # which fails on integer/bool arrays and must never touch the caller's data.
    data = np.array(data, dtype=float)

    self.axes.clear()
    if self.colorbar:
        self.colorbar.remove()
        # Forget the removed colorbar so that a later call with
        # show_colorbar=False does not try to remove it a second time.
        self.colorbar = None

    if show_names is None:
        show_names = ['O1', 'O2', 'CZ', 'T3', 'T4', 'T7', 'T8', 'FP1', 'FP2']
    highlighted = {name.upper() for name in show_names}
    mask = np.array([name.upper() in highlighted for name in names]) if names else None

    v_min, v_max = None, None
    if len(data) and (data == data[0]).all():
        # All values are equal: nudge two samples apart and pin the colour
        # limits to [-1, 1].
        if len(data) > 1:
            data[0] += 0.1
            data[1] -= 0.1
        v_min, v_max = -1, 1

    image, _ = plot_topomap(
        data, pos, axes=self.axes, show=False, contours=0, names=names, show_names=True,
        mask=mask,
        mask_params=dict(marker='o', markerfacecolor='w', markeredgecolor='w',
                         linewidth=0, markersize=3),
        vmin=v_min, vmax=v_max)

    if central_text is not None:
        self.axes.text(0, 0, central_text,
                       horizontalalignment='center', verticalalignment='center')
    if right_bottom_text is not None:
        self.axes.text(-0.65, 0.65, right_bottom_text,
                       horizontalalignment='left', verticalalignment='top')
    if show_not_found_symbol:
        self.axes.text(0, 0, '/',
                       horizontalalignment='center', verticalalignment='center')
        self.axes.text(0, 0, 'O', size=10,
                       horizontalalignment='center', verticalalignment='center')

    if show_colorbar:
        self.colorbar = self.fig.colorbar(image, orientation='horizontal', ax=self.axes)
        self.colorbar.ax.tick_params(labelsize=6)
        self.colorbar.ax.set_xticklabels(self.colorbar.ax.get_xticklabels(), rotation=90)
    self.draw()
elif j > len(p_names) - 3: x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejection if reject else None) print(x.shape) raw_after = add_data_simple(raw_after, name, x) xx = np.concatenate([ raw_before['Opened'], raw_before['Baseline'], raw_after['Baseline'], raw_before['Left'], raw_before['Right'], raw_after['Left'], raw_after['Right'], ]) #xx[:, channels.index('C3')] = xx[:, channels.index('C3')] rejection, spatial, topography, unmixing_matrix, bandpass, _ = ICADialog.get_rejection( xx, channels, fs, mode='csp', states=None) from mne.viz import plot_topomap print(spatial) fig, axes = plt.subplots(ncols=rejection.topographies.shape[1]) if not isinstance(axes, type(axes1)): axes = [axes] for ax, top in zip(axes, rejection.topographies.T): plot_topomap(top, ch_names_to_2d_pos(channels), axes=ax, show=False) fig.savefig('csp_S{}_D{}.png'.format(subj, day + 1)) fig.show()
# Fit a CSP decomposition on the recordings taken before and after training
# and save its topographies to 'csp_S<subj>_D<day>.png'.
fs, channels, p_names = get_info(f, settings['drop_channels'])
rejection, alpha, ica = load_rejections(f, reject_alpha=False)

# Split the recordings by protocol index: the first three protocols go to
# `raw_before`, the last two to `raw_after`; everything in between is skipped.
raw_before = OrderedDict()
raw_after = OrderedDict()
for j, name in enumerate(p_names):
    if j < 3:
        x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejection if reject else None)
        raw_before = add_data_simple(raw_before, name, x)
    elif j > len(p_names) - 3:
        x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejection if reject else None)
        print(x.shape)
        raw_after = add_data_simple(raw_after, name, x)

xx = np.concatenate([
    raw_before['Opened'],
    raw_before['Baseline'],
    raw_after['Baseline'],
    raw_before['Left'],
    raw_before['Right'],
    raw_after['Left'],
    raw_after['Right'],
])

# NOTE: `rejection` is rebound here to the object returned by the dialog.
rejection, spatial, topography, unmixing_matrix, bandpass, _ = ICADialog.get_rejection(
    xx, channels, fs, mode='csp', states=None)

from mne.viz import plot_topomap
print(spatial)

fig, axes = plt.subplots(ncols=rejection.topographies.shape[1])
# plt.subplots returns a bare Axes when there is a single column; wrap it so
# the loop below always iterates over a 1-D collection. This replaces a type
# comparison against the unrelated `axes1` variable.
axes = np.atleast_1d(axes)
pos = ch_names_to_2d_pos(channels)  # identical for every topography
for ax, top in zip(axes, rejection.topographies.T):
    plot_topomap(top, pos, axes=ax, show=False)
fig.savefig('csp_S{}_D{}.png'.format(subj, day + 1))
fig.show()
def __init__(self, current_protocol, protocols, signals, n_signals=1, parent=None, n_channels=32,
             max_protocol_n_samples=None, experiment=None, freq=500, plot_raw_flag=True,
             plot_signals_flag=True, plot_source_space_flag=False, show_subject_window=True,
             channels_labels=None, subject_backend_expyriment=False):
    """Build the experimenter's main window and its optional child windows.

    Creates the derived-signal and raw-signal viewers, the player buttons,
    the three plot checkboxes and the status line, arranges them in a grid
    and shows the window. Depending on the flags it also opens a subject
    window (native or Expyriment backend) and a source-space window.
    `n_signals` and `max_protocol_n_samples` are accepted but not read here.
    """
    super(MainWindow, self).__init__(parent)

    # Remember which optional windows were requested.
    self.plot_source_space_flag = plot_source_space_flag
    self.show_subject_window = show_subject_window

    # Status line fed with protocol names and durations.
    protocol_names = [p.name for p in protocols]
    protocol_durations = [p.duration for p in protocols]
    self.status = PlayerLineInfo(protocol_names, [protocol_durations])

    self.source_freq = freq
    self.experiment = experiment
    self.signals = signals

    # Player buttons: restart the experiment, and on start reset every
    # signal's statistics accumulator before updating the status.
    self.player_panel = PlayerButtonsWidget(parent=self)
    self.player_panel.restart.clicked.connect(self.restart_experiment)
    start_clicked = self.player_panel.start.clicked
    for signal in signals:
        start_clicked.connect(signal.reset_statistic_acc)
    start_clicked.connect(self.update_first_status)
    self._first_time_start_press = True

    # Timer label.
    self.timer_label = QtGui.QLabel('tf')

    # Viewers for derived signals and raw data.
    signal_names = [signal.name for signal in signals]
    self.signals_viewer = DerivedSignalViewer(freq, signal_names)
    self.raw_viewer = RawSignalViewer(freq, channels_labels)
    self.n_channels = n_channels
    self.n_samples = 2000

    # Plot option checkboxes.
    self.plot_raw_checkbox = QtGui.QCheckBox('plot raw')
    self.plot_raw_checkbox.setChecked(plot_raw_flag)
    self.plot_signals_checkbox = QtGui.QCheckBox('plot signals')
    self.plot_signals_checkbox.setChecked(plot_signals_flag)
    self.autoscale_raw_chekbox = QtGui.QCheckBox('autoscale')
    self.autoscale_raw_chekbox.setChecked(True)

    # Electrode positions for the topomap widget, which is currently
    # disabled; the lookup itself is kept as in the original code.
    electrode_pos = ch_names_to_2d_pos(channels_labels)

    # DC blocker.
    self.dc_blocker = DCBlocker()

    # Grid layout: (widget, row, column, row span, column span).
    grid = pg.LayoutWidget(self)
    placement = (
        (self.signals_viewer, 0, 0, 1, 3),
        (self.plot_raw_checkbox, 1, 0, 1, 1),
        (self.plot_signals_checkbox, 1, 2, 1, 1),
        (self.autoscale_raw_chekbox, 1, 1, 1, 1),
        (self.raw_viewer, 2, 0, 1, 3),
        (self.player_panel, 3, 0, 1, 1),
        (self.timer_label, 3, 1, 1, 1),
        (self.status, 4, 0, 1, 3),
    )
    for widget, row, col, row_span, col_span in placement:
        grid.addWidget(widget, row, col, row_span, col_span)
    # Give both viewer rows the same stretch factor.
    for viewer_row in (0, 2):
        grid.layout.setRowStretch(viewer_row, 2)
    self.setCentralWidget(grid)

    # Main window geometry.
    self.resize(800, 600)
    self.show()

    # Subject window: only the native backend is shown explicitly here.
    if not show_subject_window:
        self.subject_window = None
        self._subject_window_want_to_close = None
    else:
        if subject_backend_expyriment:
            self.subject_window = ExpyrimentSubjectWindow(self, current_protocol)
        else:
            self.subject_window = SubjectWindow(self, current_protocol)
            self.subject_window.show()
        self._subject_window_want_to_close = False

    # Optional source-space window.
    if plot_source_space_flag:
        reconstructor = SourceSpaceRecontructor(signals)
        self.source_space_window = SourceSpaceWindow(self, reconstructor)
        self.source_space_window.show()

    # Time counters and timestamps.
    self.time_counter = 0
    self.time_counter1 = 0
    self.t0 = time.time()
    self.t = self.t0
powers['{}. Closed'.format(j+1)] = pow[:len(pow)//2] powers['{}. Opened'.format(j+1)] = pow[len(pow)//2:] elif name == 'Rotate': powers['{}. Right'.format(j+1)] = pow[:len(pow)//2] powers['{}. Left'.format(j+1)] = pow[len(pow)//2:] else: powers['{}. {}'.format(j+1, name)] = pow # plot rejections for j_t in range(top_ica.shape[1]): ax = fg.add_subplot(5, top_ica.shape[1]*len(subj), top_ica.shape[1]*len(subj)*3 + top_ica.shape[1]*j_s + j_t + 1) ax.set_xlabel('ICA{}'.format(j_t+1)) labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0]) channels = [label for label in labels if label not in drop_channels] pos = ch_names_to_2d_pos(channels) plot_topomap(data=top_ica[:, j_t], pos=pos, axes=ax, show=False) for j_t in range(top_alpha.shape[1]): ax = fg.add_subplot(5, top_alpha.shape[1]*len(subj), top_alpha.shape[1]*len(subj)*4 + top_alpha.shape[1]*j_s + j_t + 1) ax.set_xlabel('CSP{}'.format(j_t+1)) labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0]) channels = [label for label in labels if label not in drop_channels] pos = ch_names_to_2d_pos(channels) plot_topomap(data=top_alpha[:, j_t], pos=pos, axes=ax, show=False) # plot powers norm = powers['{}. Baseline'.format(p_names.index('Baseline') + 1)].mean() #norm = np.mean(pow_theta) print('norm', norm) ax = fg.add_subplot(2, len(subj), j_s + 1)
elif len(lengths_buffer) > 0: lengths.append(np.mean(lengths_buffer)) lengths_buffer = [] print(lengths) return np.array(lengths) with h5py.File('{}\\{}\\{}'.format(dir_, experiment, 'experiment_data.h5')) as f: fs, channels, p_names = get_info(f, settings['drop_channels']) if reject: rejections = load_rejections(f, reject_alpha=True)[0] else: rejections = None spatial = f['protocol15/signals_stats/left/spatial_filter'][:] plot_topomap(spatial, ch_names_to_2d_pos(channels), axes=plt.gca(), show=False) plt.savefig('alphaS{}_Day{}_spatial_filter'.format(subj, day+1)) mu_band = f['protocol15/signals_stats/left/bandpass'][:] #mu_band = (12, 13) max_gap = 1 / min(mu_band) * 2 min_sate_duration = max_gap * 2 raw = OrderedDict() signal = OrderedDict() for j, name in enumerate(p_names): x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejections) raw = add_data(raw, name, x, j) signal = add_data(signal, name, f['protocol{}/signals_data'.format(j + 1)][:], j) del raw[list(raw.keys())[-1]] # make csp: if run_ica:
def plot_results(pilot_dir, subj, channel, alpha_band=(9, 14), theta_band=(3, 6), drop_channels=None,
                 dc=False, reject_alpha=True, normalize_by='opened'):
    """Plot per-protocol signal traces, normalised powers and topographies.

    For every experiment folder in `subj` the file
    ``<pilot_dir>\\<experiment>\\experiment_data.h5`` is read and one column
    of subplots is drawn: the filtered trace with the `alpha_band` signal on
    top, the mean power of each protocol (with a regression line over the
    'FB' protocols) and the ICA/CSP topographies at the bottom.

    Parameters
    ----------
    pilot_dir : str
        Directory containing one sub-folder per experiment.
    subj : sequence of str
        Experiment folder names; one subplot column per entry.
    channel : str
        Name of the channel to analyse; must be present after dropping
        `drop_channels`.
    alpha_band, theta_band : tuple
        Bands passed to ``get_protocol_power``.
    drop_channels : list of str, optional
        Channel labels to exclude.
    dc : bool
        Forwarded to ``get_protocol_power``.
    reject_alpha : bool
        Forwarded to ``load_rejections``.
    normalize_by : str
        'opened' divides powers by the mean of the '1. Opened' protocol;
        'beta' divides by the mean `theta_band` power over the 'FB'
        protocols; any other value leaves powers unnormalised (norm = 1).

    Returns
    -------
    The matplotlib figure holding all subplots.
    """
    drop_channels = drop_channels or []
    cm = get_colors()
    fg = plt.figure(figsize=(30, 6))
    n_days = len(subj)

    for j_s, experiment in enumerate(subj):
        with h5py.File('{}\\{}\\{}'.format(pilot_dir, experiment, 'experiment_data.h5')) as f:
            rejections, top_alpha, top_ica = load_rejections(f, reject_alpha=reject_alpha)
            fs, channels, p_names = get_info(f, drop_channels)
            ch = channels.index(channel)

            # collect powers
            powers = OrderedDict()
            raw = OrderedDict()
            alpha = OrderedDict()
            pow_theta = []
            for j, name in enumerate(p_names):
                power, alpha_x, x = get_protocol_power(f, j, fs, rejections, ch, alpha_band, dc=dc)
                if 'FB' in name:
                    theta_power = get_protocol_power(f, j, fs, rejections, ch, theta_band, dc=dc)[0]
                    pow_theta.append(theta_power.mean())
                powers = add_data(powers, name, power, j)
                raw = add_data(raw, name, x, j)
                alpha = add_data(alpha, name, alpha_x, j)

            # plot rejections: ICA topographies first, CSP ones right after,
            # all in the bottom row of a 4-row grid.
            n_ica = top_ica.shape[1]
            n_csp = top_alpha.shape[1]
            n_tops = n_ica + n_csp
            if n_tops:
                # Channel positions do not depend on the component, so they
                # are looked up once. As in the original code, `fs` is
                # re-read from the stream info here.
                labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                channels = [label for label in labels if label not in drop_channels]
                pos = ch_names_to_2d_pos(channels)
            first_cell = n_tops * n_days * 3 + n_tops * j_s + 1
            for j_t in range(n_ica):
                ax = fg.add_subplot(4, n_tops * n_days, first_cell + j_t)
                ax.set_xlabel('ICA{}'.format(j_t + 1))
                plot_topomap(data=top_ica[:, j_t], pos=pos, axes=ax, show=False)
            for j_t in range(n_csp):
                ax = fg.add_subplot(4, n_tops * n_days, first_cell + j_t + n_ica)
                ax.set_xlabel('CSP{}'.format(j_t + 1))
                plot_topomap(data=top_alpha[:, j_t], pos=pos, axes=ax, show=False)

            # plot powers
            if normalize_by == 'opened':
                norm = powers['1. Opened'].mean()
            elif normalize_by == 'beta':
                norm = np.mean(pow_theta)
            else:
                print('WARNING: norm = 1')
                # The warning used to be printed without the assignment, which
                # raised NameError or reused the previous experiment's value.
                norm = 1
            print('norm', norm)

            ax1 = fg.add_subplot(3, n_days, j_s + 1)
            ax = fg.add_subplot(3, n_days, j_s + n_days + 1)
            t = 0
            for j_p, (power, (name, x)) in enumerate(zip(powers.values(), raw.items())):
                if name == '2228. FB':
                    # Debug branch: periodogram of one specific protocol.
                    from scipy.signal import periodogram
                    fff = plt.figure()
                    fff.gca().plot(*periodogram(x, fs, nfft=fs * 3), c=cm[name.split()[1]])
                    plt.xlim(0, 80)
                    plt.ylim(0, 3e-11)
                    plt.show()
                print(name)
                time_axis = np.arange(t, t + len(x)) / fs
                # Colour key: the protocol name with digits stripped.
                color = cm[''.join([i for i in name.split()[1] if not i.isdigit()])]
                ax1.plot(time_axis, fft_filter(x, fs, (2, 45)), c=color, alpha=0.4)
                ax1.plot(time_axis, alpha[name], c=color)
                t += len(x)
                ax.plot([j_p], [power.mean() / norm], 'o', c=color, markersize=10)
                ax.errorbar([j_p], [power.mean() / norm], yerr=power.std() / norm,
                            c=color, ecolor=color)

            # Regression over the feedback protocols; skipped when there are
            # none, because np.hstack cannot stack an empty list.
            fb_items = [(j, pows) for j, (key, pows) in enumerate(powers.items()) if 'FB' in key]
            if fb_items:
                fb_x = np.hstack([[j] * len(pows) for j, pows in fb_items])
                fb_y = np.hstack([pows for _, pows in fb_items]) / norm
                sns.regplot(x=fb_x, y=fb_y, ax=ax, color=cm['FB'], scatter=False, truncate=True)

            ax1.set_xlim(0, t / fs)
            ax1.set_ylim(-40, 40)
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=70)
            ax.set_xticks(range(len(powers)))
            ax.set_xticklabels(powers.keys())
            ax.set_ylim(0, 3)
            ax.set_xlim(-1, len(powers))
            ax1.set_title('Day {}'.format(j_s + 1))
    return fg
elif len(lengths_buffer) > 0: lengths.append(np.mean(lengths_buffer)) lengths_buffer = [] print(lengths) return np.array(lengths) with h5py.File('{}\\{}\\{}'.format(dir_, experiment, 'experiment_data.h5')) as f: fs, channels, p_names = get_info(f, settings['drop_channels']) if reject: rejections = load_rejections(f, reject_alpha=True)[0] else: rejections = None spatial = f['protocol15/signals_stats/left/spatial_filter'][:] plot_topomap(spatial, ch_names_to_2d_pos(channels), axes=plt.gca(), show=False) plt.savefig('alphaS{}_Day{}_spatial_filter'.format(subj, day+1)) mu_band = f['protocol15/signals_stats/left/bandpass'][:] #mu_band = (12, 13) max_gap = 1 / min(mu_band) * 2 min_sate_duration = max_gap * 2 raw = OrderedDict() signal = OrderedDict() for j, name in enumerate(p_names): x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejections) raw = add_data(raw, name, x, j) signal = add_data(signal, name, f['protocol{}/signals_data'.format(j + 1)][:], j) del raw[list(raw.keys())[-1]] # make csp: if run_ica:
labels_, fs_ = get_lsl_info_from_xml(f['stream_info.xml'][0]) print(labels_) channels = [label for label in labels_ if label not in ['A1', 'A2', 'AUX']] print(labels_) pz_index = channels.index('Pz') raw = raw - np.dot(raw[:, [pz_index]], np.ones((1, raw.shape[1]))) del channels[pz_index] raw = raw[:, np.arange(raw.shape[1]) != pz_index] signal = DerivedSignal(ind=0, name='Signal', bandpass_low=9, bandpass_high=14, spatial_filter=np.array([0]), n_channels=raw.shape[1]) w = SignalsSSDManager([signal], raw, ch_names_to_2d_pos(channels), channels, None, None, [], sampling_freq=fs_ ) w.exec_() rejections = signal.rejections.get_list() new_rejections[experiment] = rejections with open(new_rejections_file, 'wb') as pkl: pickle.dump(new_rejections, pkl) del a else: print('file exist') with open(new_rejections_file, 'rb') as handle: new_rejections = pickle.load(handle) print(new_rejections)
def plot_results(pilot_dir, subj, channel, alpha_band=(9, 14), theta_band=(3, 6), drop_channels=None,
                 dc=False, reject_alpha=True, normalize_by='opened'):
    """Plot per-protocol signal traces, normalised powers and topographies.

    For every experiment folder in `subj` the file
    ``<pilot_dir>\\<experiment>\\experiment_data.h5`` is read and one column
    of subplots is drawn: the filtered trace with the `alpha_band` signal on
    top, the mean power of each protocol (with a regression line over the
    'FB' protocols) and the ICA/CSP topographies at the bottom.

    Parameters
    ----------
    pilot_dir : str
        Directory containing one sub-folder per experiment.
    subj : sequence of str
        Experiment folder names; one subplot column per entry.
    channel : str
        Name of the channel to analyse; must be present after dropping
        `drop_channels`.
    alpha_band, theta_band : tuple
        Bands passed to ``get_protocol_power``.
    drop_channels : list of str, optional
        Channel labels to exclude.
    dc : bool
        Forwarded to ``get_protocol_power``.
    reject_alpha : bool
        Forwarded to ``load_rejections``.
    normalize_by : str
        'opened' divides powers by the mean of the '1. Opened' protocol;
        'beta' divides by the mean `theta_band` power over the 'FB'
        protocols; any other value leaves powers unnormalised (norm = 1).

    Returns
    -------
    The matplotlib figure holding all subplots.
    """
    drop_channels = drop_channels or []
    cm = get_colors()
    fg = plt.figure(figsize=(30, 6))
    n_days = len(subj)

    for j_s, experiment in enumerate(subj):
        with h5py.File('{}\\{}\\{}'.format(pilot_dir, experiment, 'experiment_data.h5')) as f:
            rejections, top_alpha, top_ica = load_rejections(f, reject_alpha=reject_alpha)
            fs, channels, p_names = get_info(f, drop_channels)
            ch = channels.index(channel)

            # collect powers
            powers = OrderedDict()
            raw = OrderedDict()
            alpha = OrderedDict()
            pow_theta = []
            for j, name in enumerate(p_names):
                power, alpha_x, x = get_protocol_power(f, j, fs, rejections, ch, alpha_band, dc=dc)
                if 'FB' in name:
                    theta_power = get_protocol_power(f, j, fs, rejections, ch, theta_band, dc=dc)[0]
                    pow_theta.append(theta_power.mean())
                powers = add_data(powers, name, power, j)
                raw = add_data(raw, name, x, j)
                alpha = add_data(alpha, name, alpha_x, j)

            # plot rejections: ICA topographies first, CSP ones right after,
            # all in the bottom row of a 4-row grid.
            n_ica = top_ica.shape[1]
            n_csp = top_alpha.shape[1]
            n_tops = n_ica + n_csp
            if n_tops:
                # Channel positions do not depend on the component, so they
                # are looked up once. As in the original code, `fs` is
                # re-read from the stream info here.
                labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                channels = [label for label in labels if label not in drop_channels]
                pos = ch_names_to_2d_pos(channels)
            first_cell = n_tops * n_days * 3 + n_tops * j_s + 1
            for j_t in range(n_ica):
                ax = fg.add_subplot(4, n_tops * n_days, first_cell + j_t)
                ax.set_xlabel('ICA{}'.format(j_t + 1))
                plot_topomap(data=top_ica[:, j_t], pos=pos, axes=ax, show=False)
            for j_t in range(n_csp):
                ax = fg.add_subplot(4, n_tops * n_days, first_cell + j_t + n_ica)
                ax.set_xlabel('CSP{}'.format(j_t + 1))
                plot_topomap(data=top_alpha[:, j_t], pos=pos, axes=ax, show=False)

            # plot powers
            if normalize_by == 'opened':
                norm = powers['1. Opened'].mean()
            elif normalize_by == 'beta':
                norm = np.mean(pow_theta)
            else:
                print('WARNING: norm = 1')
                # The warning used to be printed without the assignment, which
                # raised NameError or reused the previous experiment's value.
                norm = 1
            print('norm', norm)

            ax1 = fg.add_subplot(3, n_days, j_s + 1)
            ax = fg.add_subplot(3, n_days, j_s + n_days + 1)
            t = 0
            for j_p, (power, (name, x)) in enumerate(zip(powers.values(), raw.items())):
                if name == '2228. FB':
                    # Debug branch: periodogram of one specific protocol.
                    from scipy.signal import periodogram
                    fff = plt.figure()
                    fff.gca().plot(*periodogram(x, fs, nfft=fs * 3), c=cm[name.split()[1]])
                    plt.xlim(0, 80)
                    plt.ylim(0, 3e-11)
                    plt.show()
                print(name)
                time_axis = np.arange(t, t + len(x)) / fs
                # Colour key: the protocol name with digits stripped.
                color = cm[''.join([i for i in name.split()[1] if not i.isdigit()])]
                ax1.plot(time_axis, fft_filter(x, fs, (2, 45)), c=color, alpha=0.4)
                ax1.plot(time_axis, alpha[name], c=color)
                t += len(x)
                ax.plot([j_p], [power.mean() / norm], 'o', c=color, markersize=10)
                ax.errorbar([j_p], [power.mean() / norm], yerr=power.std() / norm,
                            c=color, ecolor=color)

            # Regression over the feedback protocols; skipped when there are
            # none, because np.hstack cannot stack an empty list.
            fb_items = [(j, pows) for j, (key, pows) in enumerate(powers.items()) if 'FB' in key]
            if fb_items:
                fb_x = np.hstack([[j] * len(pows) for j, pows in fb_items])
                fb_y = np.hstack([pows for _, pows in fb_items]) / norm
                sns.regplot(x=fb_x, y=fb_y, ax=ax, color=cm['FB'], scatter=False, truncate=True)

            ax1.set_xlim(0, t / fs)
            ax1.set_ylim(-40, 40)
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=70)
            ax.set_xticks(range(len(powers)))
            ax.set_xticklabels(powers.keys())
            ax.set_ylim(0, 3)
            ax.set_xlim(-1, len(powers))
            ax1.set_title('Day {}'.format(j_s + 1))
    return fg
# Fit a CSP decomposition on the recordings taken before and after training
# and save its topographies to 'csp_S<subj>_D<day>.png'.
fs, channels, p_names = get_info(f, settings['drop_channels'])
rejection, alpha, ica = load_rejections(f, reject_alpha=False)

# Split the recordings by protocol index: the first three protocols go to
# `raw_before`, the last two to `raw_after`; everything in between is skipped.
raw_before = OrderedDict()
raw_after = OrderedDict()
for j, name in enumerate(p_names):
    if j < 3:
        x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejection if reject else None)
        raw_before = add_data_simple(raw_before, name, x)
    elif j > len(p_names) - 3:
        x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejection if reject else None)
        print(x.shape)
        raw_after = add_data_simple(raw_after, name, x)

xx = np.concatenate([
    raw_before['Opened'],
    raw_before['Baseline'],
    raw_after['Baseline'],
    raw_before['Left'],
    raw_before['Right'],
    raw_after['Left'],
    raw_after['Right'],
])

# NOTE: `rejection` is rebound here to the object returned by the dialog.
rejection, spatial, topography, unmixing_matrix, bandpass, _ = ICADialog.get_rejection(
    xx, channels, fs, mode='csp', states=None)

from mne.viz import plot_topomap
print(spatial)

fig, axes = plt.subplots(ncols=rejection.topographies.shape[1])
# plt.subplots returns a bare Axes when there is a single column; wrap it so
# the loop below always iterates over a 1-D collection. This replaces a type
# comparison against the unrelated `axes1` variable.
axes = np.atleast_1d(axes)
pos = ch_names_to_2d_pos(channels)  # identical for every topography
for ax, top in zip(axes, rejection.topographies.T):
    plot_topomap(top, pos, axes=ax, show=False)
fig.savefig('csp_S{}_D{}.png'.format(subj, day + 1))
fig.show()
with h5py.File('{}\\{}\\{}'.format(settings['dir'], experiment, 'experiment_data.h5')) as f: fs, channels, p_names = get_info(f, settings['drop_channels']) rejection, alpha, ica = None, None, None#load_rejections(f, reject_alpha=True) odict = OrderedDict() for j, name in enumerate(p_names): x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejection if reject else None) odict = add_data_simple(odict, name, x) raw[names[j_experiment]] = odict for j, key in enumerate(state_plot): f, Pxx = welch(raw[names[j_experiment]][key], fs, nperseg=2048, axis=0) #axes[j].semilogy(f, Pxx, alpha=1, c=cm[j_experiment*2+3]) ax = axes[j, j_experiment] a, b = plot_topomap(np.log10(Pxx[np.argmin(np.abs(f-peak)), :]), ch_names_to_2d_pos(channels), cmap='Reds', axes=ax, show=False, vmax=-10.5, vmin=-13) if j_experiment == 0: ax.set_ylabel(key) if j == len(state_plot)-1: ax.set_xlabel(names[j_experiment]) #axes[j].set_xlim(0, 250) #axes[j].set_ylim(1e-19, 5e-10) #x_plot = np.abs(hilbert(fft_filter(raw_before[key][:, channels.index(ch)], fs))) #leg.append('P={:.3f}, D={:.3f}s'.format(Pxx[(f > 9) & (f < 14)].mean(), sum((x_plot > 5)) / fs/2)) #axes[j].legend(leg) fig2.colorbar(a) plt.show()