Esempio n. 1
0
                        powers['{}. Closed'.format(j+1)] = pow[:len(pow)//2]
                        powers['{}. Opened'.format(j+1)] = pow[len(pow)//2:]
                    elif name == 'Rotate':
                        powers['{}. Right'.format(j+1)] = pow[:len(pow)//2]
                        powers['{}. Left'.format(j+1)] = pow[len(pow)//2:]
                    else:
                        powers['{}. {}'.format(j+1, name)] = pow


                # plot rejections
                for j_t in range(top_ica.shape[1]):
                    ax = fg.add_subplot(5, top_ica.shape[1]*len(subj), top_ica.shape[1]*len(subj)*3 + top_ica.shape[1]*j_s + j_t + 1)
                    ax.set_xlabel('ICA{}'.format(j_t+1))
                    labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                    channels = [label for label in labels if label not in drop_channels]
                    pos = ch_names_to_2d_pos(channels)
                    plot_topomap(data=top_ica[:, j_t], pos=pos, axes=ax, show=False)
                for j_t in range(top_alpha.shape[1]):
                    ax = fg.add_subplot(5, top_alpha.shape[1]*len(subj), top_alpha.shape[1]*len(subj)*4 + top_alpha.shape[1]*j_s + j_t + 1)
                    ax.set_xlabel('CSP{}'.format(j_t+1))
                    labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                    channels = [label for label in labels if label not in drop_channels]
                    pos = ch_names_to_2d_pos(channels)
                    plot_topomap(data=top_alpha[:, j_t], pos=pos, axes=ax, show=False)


                # plot powers
                norm = powers['{}. Baseline'.format(p_names.index('Baseline') + 1)].mean()
                #norm = np.mean(pow_theta)
                print('norm', norm)
                ax = fg.add_subplot(2, len(subj), j_s + 1)
Esempio n. 2
0
                elif len(lengths_buffer) > 0:
                    lengths.append(np.mean(lengths_buffer))
                    lengths_buffer = []
            print(lengths)
            return np.array(lengths)



        with h5py.File('{}\\{}\\{}'.format(dir_, experiment, 'experiment_data.h5')) as f:
            fs, channels, p_names = get_info(f, settings['drop_channels'])
            if reject:
                rejections = load_rejections(f, reject_alpha=True)[0]
            else:
                rejections = None
            spatial = f['protocol15/signals_stats/left/spatial_filter'][:]
            plot_topomap(spatial, ch_names_to_2d_pos(channels), axes=plt.gca(), show=False)
            plt.savefig('alphaS{}_Day{}_spatial_filter'.format(subj, day+1))
            mu_band = f['protocol15/signals_stats/left/bandpass'][:]
            #mu_band = (12, 13)
            max_gap = 1 / min(mu_band) * 2
            min_sate_duration = max_gap * 2
            raw = OrderedDict()
            signal = OrderedDict()
            for j, name in enumerate(p_names):
                x = preproc(f['protocol{}/raw_data'.format(j + 1)][:], fs, rejections)
                raw = add_data(raw, name, x, j)
                signal = add_data(signal, name, f['protocol{}/signals_data'.format(j + 1)][:], j)

        del raw[list(raw.keys())[-1]]
        # make csp:
        if run_ica:
Esempio n. 3
0
                                      sharey=False)
            for j, data in enumerate([first, last]):
                raw_data = np.concatenate([
                    np.concatenate(
                        [raw[key] for key in data if 'Close' in key]),
                    np.concatenate([raw[key] for key in data if 'Open' in key])
                ])
                rej, spatial, top = ICADialog.get_rejection(raw_data,
                                                            channels,
                                                            fs,
                                                            mode='csp',
                                                            states=None)[:3]
                tops.append(top)
                spats.append(spatial)
                plot_topomap(top,
                             ch_names_to_2d_pos(channels),
                             axes=axes[j, 0],
                             show=False)
                plot_topomap(spatial,
                             ch_names_to_2d_pos(channels),
                             axes=axes[j, 1],
                             show=False)
                axes[j, 0].set_xlabel(
                    'Topography ({})'.format('before' if j == 0 else 'after'))
                axes[j, 1].set_xlabel('Spatial filter ({})'.format(
                    'before' if j == 0 else 'after'))
            #plt.show()

        # plot raw data
        ch_plot = ['C3', 'P3', 'ICA']  #, 'Pz', 'Fp1']
        fig1, axes = plt.subplots(len(ch_plot),
Esempio n. 4
0
    def update_figure(self,
                      data,
                      pos=None,
                      names=None,
                      show_names=None,
                      show_colorbar=True,
                      central_text=None,
                      right_bottom_text=None,
                      show_not_found_symbol=False,
                      montage=None):
        """Clear ``self.axes`` and redraw a topographic map of ``data``.

        Args:
            data: Per-channel values; converted to a floating-point array.
            pos: Sensor positions handed to ``plot_topomap``. When both
                ``pos`` and ``montage`` are ``None`` the positions are
                looked up from ``names`` via ``ch_names_to_2d_pos``.
            names: Channel names. Ignored when ``montage`` is given.
            show_names: Channel names whose sensors are marked through the
                ``mask`` argument of ``plot_topomap`` (compared
                case-insensitively). Defaults to a fixed list of landmark
                channels. Note that ``plot_topomap`` itself is always called
                with ``show_names=True``.
            show_colorbar: If true, attach a horizontal colorbar to the axes.
            central_text: Optional text drawn centred at (0, 0).
            right_bottom_text: Optional text drawn at (-0.65, 0.65), anchored
                by its top-left corner.
            show_not_found_symbol: If true, overlay '/' and 'O' at (0, 0).
            montage: Optional object providing ``get_pos('EEG')`` and
                ``get_names('EEG')``; when given it overrides ``pos`` and
                ``names``.
        """
        if montage is None:
            if pos is None:
                pos = ch_names_to_2d_pos(names)
        else:
            pos = montage.get_pos('EEG')
            names = montage.get_names('EEG')

        data = np.array(data)
        if not np.issubdtype(data.dtype, np.inexact):
            # Integer/bool input would silently truncate the +/-0.1 nudge
            # applied to constant data below.
            data = data.astype(float)

        self.axes.clear()
        if self.colorbar:
            self.colorbar.remove()
            # Forget the removed colorbar: otherwise a call with
            # show_colorbar=False leaves a stale reference that the next
            # call would try to remove a second time.
            self.colorbar = None

        if show_names is None:
            show_names = [
                'O1', 'O2', 'CZ', 'T3', 'T4', 'T7', 'T8', 'FP1', 'FP2'
            ]
        highlighted = {name.upper() for name in show_names}
        # Explicit None/length test so that array-like `names` work too.
        has_names = names is not None and len(names) > 0
        mask = (np.array([name.upper() in highlighted for name in names])
                if has_names else None)

        v_min, v_max = None, None
        # Constant data: nudge two entries apart and pin the colour limits
        # (presumably to avoid a degenerate colour range -- TODO confirm).
        # At least two samples are required for the nudge.
        if data.ndim > 0 and len(data) > 1 and (data == data[0]).all():
            data[0] += 0.1
            data[1] -= 0.1
            v_min, v_max = -1, 1

        image, _ = plot_topomap(data,
                                pos,
                                axes=self.axes,
                                show=False,
                                contours=0,
                                names=names,
                                show_names=True,
                                mask=mask,
                                mask_params=dict(marker='o',
                                                 markerfacecolor='w',
                                                 markeredgecolor='w',
                                                 linewidth=0,
                                                 markersize=3),
                                vmin=v_min,
                                vmax=v_max)

        if central_text is not None:
            self.axes.text(0,
                           0,
                           central_text,
                           horizontalalignment='center',
                           verticalalignment='center')

        if right_bottom_text is not None:
            self.axes.text(-0.65,
                           0.65,
                           right_bottom_text,
                           horizontalalignment='left',
                           verticalalignment='top')

        if show_not_found_symbol:
            # '/' drawn over 'O' at the centre forms the "not found" glyph.
            self.axes.text(0,
                           0,
                           '/',
                           horizontalalignment='center',
                           verticalalignment='center')
            self.axes.text(0,
                           0,
                           'O',
                           size=10,
                           horizontalalignment='center',
                           verticalalignment='center')

        if show_colorbar:
            self.colorbar = self.fig.colorbar(image,
                                              orientation='horizontal',
                                              ax=self.axes)
            self.colorbar.ax.tick_params(labelsize=6)
            self.colorbar.ax.set_xticklabels(
                self.colorbar.ax.get_xticklabels(), rotation=90)
        self.draw()
Esempio n. 5
0
def plot_results(pilot_dir,
                 subj,
                 channel,
                 alpha_band=(9, 14),
                 theta_band=(3, 6),
                 drop_channels=None,
                 dc=False,
                 reject_alpha=True,
                 normalize_by='opened'):
    """Plot signals, normalized band powers and ICA topographies per experiment.

    One figure column is produced for every entry of ``subj`` (titled
    'Day 1', 'Day 2', ...). Each column contains the raw/filtered signal of
    ``channel``, the per-protocol mean power with error bars plus a linear
    fit over the 'FB' protocols, and the rejected ICA topographies.

    Args:
        pilot_dir: Directory containing one sub-directory per experiment.
        subj: Sequence of experiment sub-directory names; each must hold an
            'experiment_data.h5' file.
        channel: Name of the channel to analyse; must be present in the
            channel list returned by ``get_info``.
        alpha_band: Band passed to ``get_protocol_power`` for every protocol.
        theta_band: Band passed to ``get_protocol_power`` for protocols whose
            name contains 'FB' (used only for ``normalize_by='beta'``).
        drop_channels: Channel names to exclude; ``None`` means none.
        dc: Forwarded to ``get_protocol_power``.
        reject_alpha: Forwarded to ``load_rejections``.
        normalize_by: 'opened' divides powers by the mean of the
            '1. Opened' protocol; 'beta' divides by the mean ``theta_band``
            power over the 'FB' protocols (NOTE(review): the option name and
            the band name disagree -- confirm intent); any other value
            disables normalization (norm = 1).

    Returns:
        The matplotlib figure that was drawn on.

    Raises:
        ValueError: If ``channel`` is not in the channel list.
        KeyError: If ``normalize_by='opened'`` and there is no
            '1. Opened' entry.
    """
    import os  # function-scope import: only needed to build the data path

    drop_channels = drop_channels or []
    cm = get_colors()
    fg = plt.figure(figsize=(30, 6))
    n_subj = len(subj)
    for j_s, experiment in enumerate(subj):
        # os.path.join instead of a hard-coded '\\' keeps the path portable.
        data_path = os.path.join(pilot_dir, experiment, 'experiment_data.h5')
        with h5py.File(data_path) as f:
            rejections, top_alpha, top_ica = load_rejections(
                f, reject_alpha=reject_alpha)
            fs, channels, p_names = get_info(f, drop_channels)
            ch = channels.index(channel)

            # collect powers
            powers = OrderedDict()
            raw = OrderedDict()
            alpha = OrderedDict()
            pow_theta = []
            for j, name in enumerate(p_names):
                power, alpha_x, x = get_protocol_power(f,
                                                       j,
                                                       fs,
                                                       rejections,
                                                       ch,
                                                       alpha_band,
                                                       dc=dc)
                if 'FB' in name:
                    theta_power = get_protocol_power(f,
                                                     j,
                                                     fs,
                                                     rejections,
                                                     ch,
                                                     theta_band,
                                                     dc=dc)[0]
                    pow_theta.append(theta_power.mean())
                powers = add_data(powers, name, power, j)
                raw = add_data(raw, name, x, j)
                alpha = add_data(alpha, name, alpha_x, j)

            # plot rejections
            n_tops = top_ica.shape[1]  # CSP (top_alpha) maps are not drawn
            if n_tops > 0:
                # Sensor layout is the same for every topography of this
                # file, so it is resolved once instead of per component.
                labels, fs = get_lsl_info_from_xml(f['stream_info.xml'][0])
                channels = [
                    label for label in labels if label not in drop_channels
                ]
                pos = ch_names_to_2d_pos(channels)
            for j_t in range(n_tops):
                ax = fg.add_subplot(
                    4, n_tops * n_subj,
                    n_tops * n_subj * 3 + n_tops * j_s + j_t + 1)
                # Only ICA components are iterated here, so the label is
                # always 'ICA<k>' (the former int-vs-array comparison with
                # top_alpha could not be evaluated as a condition).
                ax.set_xlabel('ICA{}'.format(j_t + 1))
                plot_topomap(data=top_ica[:, j_t],
                             pos=pos,
                             axes=ax,
                             show=False)

            # plot powers
            if normalize_by == 'opened':
                norm = powers['1. Opened'].mean()
            elif normalize_by == 'beta':
                norm = np.mean(pow_theta)
            else:
                print('WARNING: norm = 1')
                norm = 1  # actually apply the fallback announced above
            print('norm', norm)

            ax1 = fg.add_subplot(3, n_subj, j_s + 1)
            ax = fg.add_subplot(3, n_subj, j_s + n_subj + 1)
            t = 0
            # Debug output; skipped when the protocol layout differs.
            if j_s == 0 and '4. FB' in powers and '3. Baseline' in powers:
                print(powers['4. FB'].mean() / powers['3. Baseline'].mean())
            for j_p, (power, (name, x)) in enumerate(
                    zip(powers.values(), raw.items())):
                color = cm[name.split()[1]]
                if name == '2228. FB':
                    # Debug hook: show the periodogram of one specific
                    # protocol in a separate (blocking) figure.
                    from scipy.signal import periodogram
                    fff = plt.figure()
                    fff.gca().plot(*periodogram(x, fs, nfft=fs * 3),
                                   c=color)
                    plt.xlim(0, 80)
                    plt.ylim(0, 3e-11)
                    plt.show()
                # Protocols are laid out back to back on a common time axis.
                time = np.arange(t, t + len(x)) / fs
                ax1.plot(time, x, c=color, alpha=0.4)
                ax1.plot(time, alpha[name], c=color)
                t += len(x)
                ax.plot([j_p], [power.mean() / norm],
                        'o',
                        c=color,
                        markersize=10)
                ax.errorbar([j_p], [power.mean() / norm],
                            yerr=power.std() / norm,
                            c=color,
                            ecolor=color)

            # Linear trend over all feedback ('FB') protocols; skipped when
            # there are none because np.hstack rejects an empty sequence.
            fb_items = [(j, pows)
                        for j, (key, pows) in enumerate(powers.items())
                        if 'FB' in key]
            if fb_items:
                fb_x = np.hstack([[j] * len(pows) for j, pows in fb_items])
                fb_y = np.hstack([pows for _, pows in fb_items]) / norm
                sns.regplot(x=fb_x,
                            y=fb_y,
                            ax=ax,
                            color=cm['FB'],
                            scatter=False,
                            truncate=True)

            ax1.set_xlim(0, t / fs)
            ax1.set_ylim(-40, 40)
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=70)
            ax.set_xticks(range(len(powers)))
            ax.set_xticklabels(list(powers.keys()))
            ax.set_ylim(0, 3)
            ax.set_xlim(-1, len(powers))
            ax1.set_title('Day {}'.format(j_s + 1))
    return fg