Exemplo n.º 1
0
 def fs_fft_output(self, layer_name):
     """Plot FFT amplitude spectra of gradient-weighted layer outputs.

     For every layer returned by ``self._check_layer_name(layer_name)``:

     * run the model on ``self.data['x']`` and take the class predicted for
       the FIRST sample (``argmax(Pred[0])``);
     * weight the layer output channel-wise by the gradient of that class
       score (``Pred[:, index]``) summed over axes 0-2, then average over
       axis 0;
     * draw one subplot per channel (two columns) showing, for every row of
       that channel's map, the FFT amplitude taken along axis 1, divided by
       the transform length and cut to the first 101 bins;
     * draw a second figure with that spectrum averaged over rows and
       channels.

     The per-channel grid figure (not the averaged one) is written to
     ``fs_fft_output.png`` in the working directory. With several layers,
     each iteration overwrites the file.

     Args:
         layer_name: forwarded to ``self._check_layer_name``; presumably a
             layer name or a list of names (TODO confirm).
     """
     layer_name = self._check_layer_name(layer_name)
     _input = self.model.input
     _output = self.model.output
     max_bins = 101  # number of leading FFT bins kept for plotting
     for name in layer_name:
         layer = self.layer_dict[name]
         n_maps = layer.output.shape[-1]
         # Two columns. Round the row count up so an odd channel count still
         # gets a slot. Keep `axs` 2-D (squeeze=False) so `axs[row, col]`
         # also works when the grid has only one row.
         fig, axs = plt.subplots((n_maps + 1) // 2,
                                 2,
                                 sharex=True,
                                 squeeze=False)
         if n_maps % 2:
             axs[-1, -1].axis('off')  # unused trailing slot
         layer_model = tf.keras.Model([_input], [layer.output, _output])
         with tf.GradientTape() as g:
             conv_output, Pred = layer_model(self.data['x'])
             index = np.argmax(Pred[0])
             prob = Pred[:, index]
         # Take the gradient after leaving the tape context so the gradient
         # computation itself is not recorded on the tape.
         grads = g.gradient(prob, conv_output)
         pooled_grads = K.sum(grads, axis=(0, 1, 2))
         # The [:, :, i] indexing below requires a rank-4 layer output.
         # After averaging over axis 0, the last axis is the channel.
         selected = np.array(
             tf.reduce_mean(tf.multiply(pooled_grads, conv_output), axis=0))
         for i in range(selected.shape[-1]):
             ax = axs[i // 2, i % 2]
             spectrum = np.abs(fft(selected[:, :, i], axis=1))
             spectrum = spectrum / spectrum.shape[1]
             spectrum = spectrum[:, :max_bins]
             # One colour per row, spread evenly over the Spectral_r colormap.
             ax.set_prop_cycle('color', [
                 plt.cm.Spectral_r(level)
                 for level in np.linspace(0, 1, spectrum.shape[0])
             ])
             freqs = np.arange(spectrum.shape[1])
             for row in spectrum:
                 ax.plot(freqs, row)
             ax.autoscale(enable=True, axis='both', tight=True)
             ax.set_xlabel(chr(ord('a') + i))
         plt.subplots_adjust(right=0.98,
                             left=0.05,
                             top=0.99,
                             bottom=0.09,
                             wspace=0.15,
                             hspace=0.5)
         plt.show(block=False)
         mean_spectrum = np.average(
             (np.abs(fft(selected, axis=1)) /
              selected.shape[1])[:, :max_bins, :],
             axis=(0, -1))
         plt.figure()
         plt.plot(np.arange(len(mean_spectrum)), mean_spectrum)
         plt.autoscale(enable=True, axis='both', tight=True)
         plt.tight_layout()
         plt.margins(0, 0)
         # Saves `fig` (the per-channel grid), not the averaged figure above.
         fig.savefig('fs_fft_output.png',
                     format='png',
                     transparent=False,
                     dpi=300,
                     pad_inches=0)
         plt.show(block=False)
Exemplo n.º 2
0
 def fs_class_topo_kernel(self, layer_name):
     """Draw one topomap per class of gradient-weighted kernel magnitudes.

     For every layer returned by ``self._check_layer_name(layer_name)`` and
     every class ``c`` in ``self._class_data(self.data)``:

     * the gradient of ``Pred[:, c]`` with respect to the layer output, on
       that class's samples, is summed over axes 0-2;
     * it is reshaped to the last two axes of the layer's first weight array
       and multiplied with those weights;
     * the mean absolute value over weight axes 1-3 gives one value per
       index of weight axis 0;
     * those values are drawn with ``viz.plot_topomap`` at ``self.locs``.

     NOTE(review): this assumes weight axis 0 lines up with ``self.locs`` /
     ``self.sensors_name`` (e.g. a spatial depthwise kernel). Confirm.

     Subplots are placed by (integer) class key in a two-column grid that is
     at least 2x2. The figure is written to ``fs_class_topo_kernel.png`` in
     the working directory, overwritten for each layer.

     Args:
         layer_name: forwarded to ``self._check_layer_name``.
     """
     layer_name = self._check_layer_name(layer_name)
     class_data = self._class_data(self.data)
     _input = self.model.input
     _output = self.model.output
     plt.rcParams['font.size'] = 12
     # The subplot slot is derived from the class key, so size the grid from
     # the largest key while keeping the 2x2 minimum layout.
     n_rows = max(2, (max(class_data, default=-1) + 2) // 2)
     for name in layer_name:
         layer = self.layer_dict[name]
         _weights = layer.get_weights()[0]
         kernel_shape = (_weights.shape[-2], _weights.shape[-1])
         layer_model = tf.keras.Model([_input], [layer.output, _output])
         fig = plt.figure()
         for c in class_data:
             with tf.GradientTape() as g:
                 conv_output, Pred = layer_model(class_data[c]['x'])
                 prob = Pred[:, c]
             # Gradient taken outside the context so it is not recorded.
             grads = g.gradient(prob, conv_output)
             pooled_grads = tf.reshape(K.sum(grads, axis=(0, 1, 2)),
                                       shape=kernel_shape)
             s_weights = tf.reduce_mean(tf.abs(
                 tf.multiply(pooled_grads, _weights)),
                                        axis=(1, 2, 3))
             ax = fig.add_subplot(n_rows, 2, c + 1)
             ax.set_xlabel('({}) {}'.format(chr(c + 97),
                                            self.class_names[c]))
             # No axes argument is passed. plot_topomap presumably draws on
             # the current axes, which add_subplot has just made `ax`.
             viz.plot_topomap(np.array(s_weights),
                              self.locs,
                              names=self.sensors_name,
                              show_names=True,
                              show=False,
                              image_interp='bicubic',
                              cmap='Spectral_r',
                              extrapolate='head',
                              sphere=(0, 0, 0, 1))  # draw topographic image
         plt.tight_layout(pad=0)
         plt.margins(0, 0)
         fig.savefig('fs_class_topo_kernel.png',
                     format='png',
                     transparent=False,
                     dpi=300,
                     pad_inches=0)
         plt.show(block=False)
Exemplo n.º 3
0
 def fs_class_fft_output(self, layer_name):
     """Plot, per class, the mean FFT amplitude of gradient-weighted outputs.

     For every layer returned by ``self._check_layer_name(layer_name)`` and
     every class ``c`` in ``self._class_data(self.data)``:

     * the layer output on that class's samples is weighted by the gradient
       of ``Pred[:, c]``, summed over axes 0-2;
     * its FFT amplitude along axis 2 is averaged over axes 0, 1 and the
       last axis, divided by its length, and cut to the first 101 bins;
     * the result is plotted in a two-column grid (at least 2x2), with
       subplots placed by the (integer) class key.

     The figure is written to ``fs_class_fft_output.png`` in the working
     directory, overwritten for each layer.

     Args:
         layer_name: forwarded to ``self._check_layer_name``.
     """
     layer_name = self._check_layer_name(layer_name)
     class_data = self._class_data(self.data)
     _input = self.model.input
     _output = self.model.output
     plt.rcParams['font.size'] = 12
     # The subplot slot is derived from the class key, so size the grid from
     # the largest key while keeping the 2x2 minimum layout.
     n_rows = max(2, (max(class_data, default=-1) + 2) // 2)
     for name in layer_name:
         layer = self.layer_dict[name]
         layer_model = tf.keras.Model([_input], [layer.output, _output])
         fig, axs = plt.subplots(n_rows,
                                 2,
                                 sharex=True,
                                 sharey=True,
                                 squeeze=False)
         for c in class_data:
             with tf.GradientTape() as g:
                 conv_output, Pred = layer_model(class_data[c]['x'])
                 prob = Pred[:, c]
             # Gradient taken outside the context so it is not recorded.
             grads = g.gradient(prob, conv_output)
             pooled_grads = K.sum(grads, axis=(0, 1, 2))
             selected = np.array(tf.multiply(pooled_grads, conv_output))
             spectrum = np.average(np.abs(fft(selected, axis=2)),
                                   axis=(0, 1, -1))
             spectrum = spectrum / len(spectrum)
             spectrum = spectrum[:101]
             ax = axs[c // 2, c % 2]
             ax.plot(np.arange(len(spectrum)), spectrum)
             ax.set_xlabel('({}) {}'.format(chr(c + 97),
                                            self.class_names[c]))
             ax.autoscale(enable=True, axis='both', tight=True)
         plt.subplots_adjust(right=0.97,
                             left=0.05,
                             top=0.96,
                             bottom=0.10,
                             wspace=0.13,
                             hspace=0.10)
         plt.margins(0, 0)
         fig.savefig('fs_class_fft_output.png',
                     format='png',
                     transparent=False,
                     dpi=300,
                     pad_inches=0)
         plt.show(block=False)
Exemplo n.º 4
0
 def fs_class_fft_output(self, layer_name):
     """Overlay the per-class mean FFT amplitude of gradient-weighted outputs.

     For every layer returned by ``self._check_layer_name(layer_name)`` and
     every class ``c`` in ``self._class_data(self.data)``:

     * the layer output on that class's samples is weighted by the gradient
       of ``Pred[:, c]``, summed over axes 0-2;
     * its FFT amplitude along axis 2 is averaged over axes 0, 1 and the
       last axis, divided by its length, and cut to the first 101 bins;
     * it is drawn as one labelled line on a single axes.

     Line colours cycle through red, blue, green and yellow. The figure is
     written to ``fs_class_fft_output.png`` in the working directory,
     overwritten for each layer.

     NOTE(review): another method with this same name appears earlier in
     this listing. If both live in one class, only the later definition is
     kept.

     Args:
         layer_name: forwarded to ``self._check_layer_name``.
     """
     layer_name = self._check_layer_name(layer_name)
     class_data = self._class_data(self.data)
     _input = self.model.input
     _output = self.model.output
     plt.rcParams['font.size'] = 12
     for name in layer_name:
         layer = self.layer_dict[name]
         layer_model = tf.keras.Model([_input], [layer.output, _output])
         fig, ax = plt.subplots()
         ax.set_prop_cycle('color', 'rbgy')  # repeats beyond four classes
         for c in class_data:
             with tf.GradientTape() as g:
                 conv_output, Pred = layer_model(class_data[c]['x'])
                 prob = Pred[:, c]
             # Gradient taken outside the context so it is not recorded.
             grads = g.gradient(prob, conv_output)
             pooled_grads = K.sum(grads, axis=(0, 1, 2))
             # Build the spectrum from this class's own activations right
             # away, instead of storing them in a list whose position need
             # not equal the class key.
             selected = np.array(tf.multiply(pooled_grads, conv_output))
             spectrum = np.mean(np.abs(fft(selected, axis=2)),
                                axis=(0, 1, -1))
             spectrum = spectrum / len(spectrum)
             spectrum = spectrum[:101]
             ax.plot(np.arange(len(spectrum)),
                     spectrum,
                     label='({}) {}'.format(chr(c + 97),
                                            self.class_names[c]))
         ax.set_xlabel('Frequency /Hz')
         ax.set_ylabel('Amplitude')
         ax.autoscale(enable=True, axis='both', tight=True)
         plt.legend(loc='upper right')
         plt.tight_layout(pad=0.25, h_pad=0, w_pad=1)
         fig.savefig('fs_class_fft_output.png',
                     format='png',
                     transparent=False,
                     dpi=300,
                     pad_inches=0)
         plt.show(block=False)
Exemplo n.º 5
0
 def fs_class_freq_topo_kernel(self, layer_name):
     """Draw, per class, a row of band-wise topomaps of kernel contributions.

     For every layer returned by ``self._check_layer_name(layer_name)`` and
     every class ``c`` in ``self._class_data(self.data)``:

     * the class's samples are passed to ``interestingband`` with the band
       labels in ``ib`` ('2-8' ... '30-60'). This presumably band-pass
       filters them into those Hz ranges (TODO confirm).
     * For each resulting band, the gradient of ``Pred[:, c]`` with respect
       to the layer output is summed over axes 0-2 and reshaped to the last
       two axes of the layer's first weight array.
     * That gradient multiplies the weights. The mean absolute value over
       weight axes 1-3 gives one value per index of weight axis 0.
     * The band x value matrix goes through ``normalization(..., axis=None)``.
       This is presumably a joint rescale to [0, 1], matching the '0' / '1'
       colourbar labels (confirm).
     * Each band is drawn with ``viz.plot_topomap`` side by side, sharing a
       tick-less horizontal colourbar.

     Class panels are placed by (integer) class key in a two-column grid
     that is at least 2x2. The figure is written to
     ``fs_class_freq_topo_kernel.png`` in the working directory, overwritten
     for each layer.

     Args:
         layer_name: forwarded to ``self._check_layer_name``.
     """
     layer_name = self._check_layer_name(layer_name)
     class_data = self._class_data(self.data)
     _input = self.model.input
     _output = self.model.output
     ib = ['2-8', '8-12', '12-20', '20-30', '30-60']
     plt.rcParams['font.size'] = 12
     # The panel slot is derived from the class key, so size the grid from
     # the largest key while keeping the 2x2 minimum layout.
     n_rows = max(2, (max(class_data, default=-1) + 2) // 2)

     def _inset(parent, bbox):
         # Borderless inset axes filling `bbox` (in `parent` axes fractions).
         return inset_axes(parent,
                           '100%',
                           '100%',
                           bbox_to_anchor=bbox,
                           bbox_transform=parent.transAxes,
                           borderpad=0)

     for name in layer_name:
         layer = self.layer_dict[name]
         _weights = layer.get_weights()[0]
         kernel_shape = (_weights.shape[-2], _weights.shape[-1])
         layer_model = tf.keras.Model([_input], [layer.output, _output])
         fig = plt.figure(figsize=(8, 6))
         gs = fig.add_gridspec(n_rows, 2)
         for c in class_data:
             axs = fig.add_subplot(gs[c])
             # Panel layout: the topomap strip fills the middle half, and a
             # title / '0' / colourbar / '1' / 'contribution' line sits above.
             ax = _inset(axs, (0, 0.25, 1, 0.5))
             cax = _inset(axs, (0.65, 0.65, 0.09, 0.03))
             text_0 = _inset(axs, (0.62, 0.65, 0.03, 0.03))
             text_1 = _inset(axs, (0.74, 0.65, 0.03, 0.03))
             cax_text = _inset(axs, (0.77, 0.65, 0.23, 0.03))
             title = _inset(axs, (0, 0.65, 0.62, 0.03))
             for panel in (axs, ax, cax, cax_text, text_0, text_1, title):
                 self._ban_axis(panel)
             ax.set_xlabel('({}) {}'.format(chr(c + 97),
                                            self.class_names[c]))
             ibclass_data = interestingband(class_data[c]['x'],
                                            ib,
                                            axis=-2,
                                            swapaxes=False)
             cax_text.text(0.5,
                           0.5,
                           'contribution',
                           horizontalalignment='center',
                           verticalalignment='center',
                           transform=cax_text.transAxes)
             text_0.text(0,
                         0.5,
                         '0',
                         horizontalalignment='left',
                         verticalalignment='center',
                         transform=text_0.transAxes)
             text_1.text(1,
                         0.5,
                         '1',
                         horizontalalignment='right',
                         verticalalignment='center',
                         transform=text_1.transAxes)
             title.text(0.02,
                        0.5,
                        'Inter-band topomap of class',
                        horizontalalignment='left',
                        verticalalignment='center',
                        transform=title.transAxes)
             band_weights = []
             for i in range(ibclass_data.shape[0]):
                 with tf.GradientTape() as g:
                     conv_output, Pred = layer_model(ibclass_data[i])
                     prob = Pred[:, c]
                 # Gradient taken outside the context so it is not recorded.
                 grads = g.gradient(prob, conv_output)
                 pooled_grads = tf.reshape(K.sum(grads, axis=(0, 1, 2)),
                                           shape=kernel_shape)
                 band_weights.append(
                     np.mean(np.abs(np.array(pooled_grads * _weights)),
                             axis=(1, 2, 3)))
             s_weights = normalization(np.array(band_weights), axis=None)
             n_bands = s_weights.shape[0]
             width = 1. / max(n_bands, 1)  # guard: no bands -> no insets
             im = None  # reset per class so a stale image is never reused
             for i in range(n_bands):
                 ax_i = _inset(ax, (0 + width * i, 0, width, 1))
                 ax_i.set_xlabel(ib[i] + 'Hz')
                 # No axes argument is passed. plot_topomap presumably draws
                 # on the current axes, which the inset has just become.
                 im, _ = viz.plot_topomap(
                     s_weights[i, :],
                     self.locs,
                     names=self.sensors_name,
                     show_names=True,
                     show=False,
                     image_interp='bicubic',
                     cmap='RdBu_r',
                     extrapolate='head',
                     sphere=(0, 0, 0, 1))  # draw topographic image
             if im is not None:
                 cbar = plt.colorbar(im, cax=cax, orientation='horizontal')
                 cbar.set_ticks([])
         plt.tight_layout(pad=0.25, h_pad=0, w_pad=1)
         fig.savefig('fs_class_freq_topo_kernel.png',
                     format='png',
                     transparent=False,
                     dpi=300,
                     pad_inches=0)
         plt.show(block=False)