def ax_scalp(v, channels, ax=None, annotate=False, vmin=None, vmax=None,
             cmap=cm.coolwarm, scalp_line_width=1, scalp_line_style='solid',
             chan_pos_list=CHANNEL_10_20_APPROX, interpolation='bilinear',
             fontsize=8):
    """Draw a scalp plot.

    Draws a scalp plot on an existing axes. The method takes an array of
    values and an array of the corresponding channel names. It matches
    the channel names with a channel position list to project them
    correctly on the scalp.

    Parameters
    ----------
    v : 1d-array of floats
        The values for the channels
    channels : 1d array of strings
        The corresponding channel names for the values in ``v``
    ax : Axes, optional
        The axes to draw the scalp plot on. If not provided, the
        currently activated axes (i.e. ``gca()``) will be taken
    annotate : Boolean, optional
        Draw the channel names at the channel positions.
    vmin, vmax : float, optional
        The display limits for the values in ``v``. If the data in ``v``
        contains values between -3..3 and ``vmin`` and ``vmax`` are set
        to -1 and 1, all values smaller than -1 and bigger than 1 will
        appear the same as -1 and 1. If ``vmin`` is not set, ``vmax``
        must not be set either and the maximum absolute value in ``v``
        is taken to calculate both values.
    cmap : matplotlib.colors.colormap, optional
        A colormap to define the color transitions.
    scalp_line_width : float
        Line width for outline of scalp. The outline (head, nose, ears)
        is only drawn if this is greater than 0.
    scalp_line_style : str
        Line style for outline of scalp
    chan_pos_list : iterable of tuples
        First entry should be 'angle' or 'cartesian', remaining entries
        2-tuples of x and y.
    interpolation : str
        Either ``'bilinear'`` or ``'nearest'``. Selects how values
        between channel positions are computed and is also passed on to
        ``imshow``.
    fontsize : int
        Font size of the channel names drawn when ``annotate`` is set.

    Returns
    -------
    image : AxesImage
        The image created by ``ax.imshow`` for the interpolated map.

    Raises
    ------
    AssertionError
        If ``v`` and ``channels`` differ in length, ``interpolation`` is
        not one of the supported values, ``vmax`` is given without
        ``vmin``, or a channel has no known position.

    Notes
    -----
    Code adapted from Wyrm [1]_ toolbox https://github.com/bbci/wyrm.

    References
    ----------
    .. [1] Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
       Glasstetter, M., Eggensperger, K., Tangermann, M., ... & Ball, T.
       (2017). Deep learning with convolutional neural networks for EEG
       decoding and visualization. arXiv preprint arXiv:1703.05051.
    """
    if ax is None:
        ax = plt.gca()
    assert len(v) == len(channels), "Should be as many values as channels"
    assert interpolation == 'bilinear' or interpolation == 'nearest'
    if vmin is None:
        # Default to colour limits that are symmetric around zero.
        assert vmax is None
        max_abs = np.max(np.abs(v))
        vmin, vmax = -max_abs, max_abs

    # Resolve every channel to its 2D position; an unknown channel is
    # reported by name instead of failing later on a ``None`` position.
    points = []
    for c in channels:
        pos = get_channelpos(c, chan_pos_list)
        assert pos is not None, ("Expect " + c + " to exist in positions")
        points.append(pos)

    z = list(v)
    x = [p[0] for p in points]
    y = [p[1] for p in points]
    x_min, x_max = min(x), max(x)
    y_min, y_max = min(y), max(y)

    # Evaluate the interpolator on a regular 500 x 500 grid spanning the
    # bounding box of the channel positions. With the default 'xy'
    # indexing of meshgrid, entry [i_y, i_x] of the grids holds
    # (xx[i_x], yy[i_y]), i.e. rows run along y -- the layout imshow
    # expects together with origin='lower'.
    xx = np.linspace(x_min, x_max, 500)
    yy = np.linspace(y_min, y_max, 500)
    xx_grid, yy_grid = np.meshgrid(xx, yy)
    if interpolation == 'bilinear':
        f = interpolate.LinearNDInterpolator(list(zip(x, y)), z)
        # Grid points outside the convex hull of the channels come back
        # as NaN here, so no NaN check is applied to this branch.
        zz = f(xx_grid, yy_grid)
    else:
        f = interpolate.NearestNDInterpolator(list(zip(x, y)), z)
        # One vectorised call replaces the former per-pixel double loop
        # (which relied on the Python-2-only ``xrange``).
        zz = np.asarray(f(xx_grid, yy_grid), dtype=float)
        assert not np.any(np.isnan(zz))

    # plot map
    image = ax.imshow(zz, vmin=vmin, vmax=vmax, cmap=cmap,
                      extent=[x_min, x_max, y_min, y_max], origin='lower',
                      interpolation=interpolation)
    if scalp_line_width > 0:
        # paint the head
        ax.add_artist(plt.Circle((0, 0), 1, linestyle=scalp_line_style,
                                 linewidth=scalp_line_width, fill=False))
        # add a nose
        ax.plot([-0.1, 0, 0.1], [1, 1.1, 1], color='black',
                linewidth=scalp_line_width, linestyle=scalp_line_style)
        # add ears
        _add_ears(ax, scalp_line_width, scalp_line_style)

    # set the axes limits, so the figure is centered on the scalp
    ax.set_ylim([-1.05, 1.15])
    ax.set_xlim([-1.15, 1.15])

    # hide the frame and ticks
    ax.set_frame_on(False)
    ax.set_xticks([])
    ax.set_yticks([])

    # draw the channel names
    if annotate:
        for name, position in zip(channels, zip(x, y)):
            ax.annotate(" " + name, position, horizontalalignment="center",
                        verticalalignment='center', fontsize=fontsize)
    ax.set_aspect(1)
    return image
bands = [alpha_band, beta_band, high_gamma_band] for band in bands: band['i_start'] = np.searchsorted(freqs, band['start']) band['i_stop'] = np.searchsorted(freqs, band['stop']) + 1 band['freq_corr'] = np.mean(amp_pred_corrs[:, band['i_start']:band['i_stop']], axis=1) from braindecode.datasets.sensor_positions import get_channelpos, CHANNEL_10_20_APPROX ch_names = [ 'Fz', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'P1', 'Pz', 'P3', 'POz' ] positions = [get_channelpos(name, CHANNEL_10_20_APPROX) for name in ch_names] positions = np.array(positions) import matplotlib.pyplot as plt from matplotlib import cm max_abs_val = np.max([np.abs(band['freq_corr']) for band in bands]) fig, axes = plt.subplots(len(bands), global_vars.get('n_classes')) class_names = ['Left Hand', 'Right Hand', 'Feet', 'Tongue'] for band_i, band in enumerate(bands): for i_class in range(global_vars.get('n_classes')): ax = axes[band_i, i_class] mne.viz.plot_topomap(band['freq_corr'][:, i_class], positions, vmin=-max_abs_val,
ch_names_kara = [ 'Fp1', 'Fpz', 'Fp2', 'Af3', 'Af4', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'Ft7', 'Fc5', 'Fc3', 'Fc1', 'Fcz', 'Fc2', 'Fc4', 'Fc6', 'Ft8', 'T7', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'T8', 'Tp7', 'Cp5', 'Cp3', 'Cp1', 'Cpz', 'Cp2', 'Cp4', 'Cp6', 'Tp8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'Po7', 'Po5', 'Po3', 'Poz', 'Po4', 'Po6', 'Po8', 'P9', 'Oz', 'O2', 'P10', 'O1' ] # Cb1 and Cb2 replaced with P9 and P10 respectively ch_names_epoc = [ 'Af3', 'F7', 'F3', 'Fc5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'Fc6', 'F4', 'F8', 'Af4' ] positions = lambda x: np.array( [get_channelpos(name, CHANNEL_10_20_APPROX) for name in x]) # For plotting accuracies over time: labels_per_trial_per_crop = model.predict_classes(test_set.X, individual_crops=True) accs_per_crop = np.mean( [l == y for l, y in zip(labels_per_trial_per_crop, test_set.y)], axis=0) cropped_outs = model.predict_outs(test_set.X, individual_crops=True) ##################################RESULTS DISPLAY FOR PUBLICATION################################### with open('Deep4Net.pkl', 'wb') as f: d = {} d['epochs_df'] = model.epochs_df # Display monitored values as pandas dataframe d['evaluate'] = model.evaluate(