def main():
    '''
    .. _lif_fc_mnist.main-cn:

    .. _lif_fc_mnist.main-en:

    :return: None
    :raises ValueError: if the number of simulating steps ``T`` is smaller than 1

    Interactively train and evaluate a spiking classifier on MNIST.

    The network is ``Flatten -> Linear(28 * 28, 10, bias=False) -> LIFNode(tau)``.
    All hyper-parameters (device, dataset directory, batch size, learning rate,
    simulating steps ``T``, ``tau``, number of epochs, tensorboard log directory)
    are read from standard input.

    For every batch the network is run for ``T`` steps on freshly encoded input
    and the output spikes are summed. Training minimizes the MSE between the
    firing rate (spike count divided by ``T``) and the one-hot label. After each
    epoch the accuracy on the test set is computed, logged and printed.

    After training, the first test sample is simulated once more with the
    monitor enabled, and the following files are written to the current
    working directory: ``v_t_array.npy``, ``s_t_array.npy``, ``train_accs.npy``
    and ``test_accs.npy``.
    '''
    device = input(
        '输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
    dataset_dir = input(
        '输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": '
    )
    batch_size = int(
        input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
    learning_rate = float(
        input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
    T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
    tau = float(
        input(
            '输入LIF神经元的时间常数tau,例如“100.0”\n input membrane time constant, tau, for LIF neurons, e.g., "100.0": '
        ))
    train_epoch = int(
        input(
            '输入训练轮数,即遍历训练集的次数,例如“100”\n input training epochs, e.g., "100": '))
    log_dir = input(
        '输入保存tensorboard日志文件的位置,例如“./”\n input root directory for saving tensorboard logs, e.g., "./": '
    )

    # Fail fast: with T < 1 no forward pass would ever run and the firing rate
    # (spike count / T) is undefined, so there is nothing meaningful to train.
    if T < 1:
        raise ValueError(f'simulating steps T must be at least 1, got {T}')

    def count_output_spikes(model, spike_encoder, img):
        # Run the network for T steps and return the summed output spikes.
        # The encoder is invoked again for every step, once per forward pass.
        spikes = model(spike_encoder(img).float())
        for _ in range(T - 1):
            spikes += model(spike_encoder(img).float())
        return spikes

    writer = SummaryWriter(log_dir)
    try:
        # Build the data loaders.
        train_dataset = torchvision.datasets.MNIST(
            root=dataset_dir,
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True)
        test_dataset = torchvision.datasets.MNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True)

        train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        drop_last=True)
        test_data_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                       batch_size=batch_size,
                                                       shuffle=False,
                                                       drop_last=False)

        # Define the network and move it to the requested device.
        net = nn.Sequential(nn.Flatten(),
                            nn.Linear(28 * 28, 10, bias=False),
                            neuron.LIFNode(tau=tau))
        net = net.to(device)
        # Adam optimizer.
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
        # Poisson encoder turning images into input spikes.
        encoder = encoding.PoissonEncoder()

        train_times = 0
        max_test_accuracy = 0
        test_accs = []
        train_accs = []

        for epoch in range(train_epoch):
            net.train()
            for img, label in tqdm(train_data_loader):
                img = img.to(device)
                label = label.to(device)
                label_one_hot = F.one_hot(label, 10).float()

                optimizer.zero_grad()

                # out_spikes_counter has shape [batch_size, 10]: the number of
                # spikes fired by each of the 10 output neurons over T steps.
                out_spikes_counter = count_output_spikes(net, encoder, img)

                # Dividing by T gives the firing rate of each output neuron.
                out_spikes_counter_frequency = out_spikes_counter / T

                # MSE between firing rates and the one-hot label: pushes the
                # rate of neuron i towards 1 for class i and the others to 0.
                loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
                loss.backward()
                optimizer.step()
                # Spiking neurons are stateful, so the network must be reset
                # after every optimization step.
                functional.reset_net(net)

                # The predicted class is the index of the output neuron with
                # the highest firing rate.
                accuracy = (out_spikes_counter_frequency.max(1)[1] == label
                            ).float().mean().item()

                writer.add_scalar('train_accuracy', accuracy, train_times)
                train_accs.append(accuracy)

                train_times += 1

            net.eval()
            with torch.no_grad():
                # Evaluate on the whole test set once per epoch.
                test_sum = 0
                correct_sum = 0
                for img, label in test_data_loader:
                    img = img.to(device)
                    out_spikes_counter = count_output_spikes(net, encoder, img)

                    correct_sum += (out_spikes_counter.max(1)[1] == label.to(
                        device)).float().sum().item()
                    test_sum += label.numel()
                    functional.reset_net(net)
                test_accuracy = correct_sum / test_sum
                writer.add_scalar('test_accuracy', test_accuracy, epoch)
                test_accs.append(test_accuracy)
                max_test_accuracy = max(max_test_accuracy, test_accuracy)
            print(
                f'Epoch {epoch}: device={device}, dataset_dir={dataset_dir}, batch_size={batch_size}, learning_rate={learning_rate}, T={T}, log_dir={log_dir}, max_test_accuracy={max_test_accuracy}, train_times={train_times}'
            )
    finally:
        # Always release the event writer, even when training fails part-way.
        writer.close()

    # Save data for plotting: simulate the first test sample with the monitor on.
    net.eval()
    functional.set_monitor(net, True)
    with torch.no_grad():
        img, _ = test_dataset[0]
        img = img.to(device)
        out_spikes_counter = count_output_spikes(net, encoder, img)
        out_spikes_counter_frequency = (out_spikes_counter / T).cpu().numpy()
        print(f'Firing rate: {out_spikes_counter_frequency}')
        output_layer = net[-1]  # output layer
        # v_t_array[i][j]: voltage of neuron i at time step j
        v_t_array = np.asarray(output_layer.monitor['v']).squeeze().T
        np.save("v_t_array.npy", v_t_array)
        # s_t_array[i][j]: spike (0 or 1) released by neuron i at time step j
        s_t_array = np.asarray(output_layer.monitor['s']).squeeze().T
        np.save("s_t_array.npy", s_t_array)

    train_accs = np.array(train_accs)
    np.save('train_accs.npy', train_accs)
    test_accs = np.array(test_accs)
    np.save('test_accs.npy', test_accs)
def play(use_cuda, pt_path, env_name, hidden_size, played_frames=60, save_fig_num=0, fig_dir=None, figsize=(12, 6),
         firing_rates_plot_type='bar', heatmap_shape=None):
    '''
    Replay a trained ``DQSN`` policy in a gym environment and visualize it live.

    Every frame shows three panels: the rendered game screen, a bar chart of
    the two output values returned by the policy network (the chosen action is
    drawn in red), and the firing rates of the neurons monitored at
    ``policy_net.fc[1]``.

    The plotting code is written for an environment with two actions and at
    least four state components (the labels are the CartPole ones).

    :param use_cuda: run on ``"cuda"`` when true, otherwise on ``"cpu"``
    :param pt_path: path of the state dict loaded into the policy network
    :param env_name: environment id passed to ``gym.make``
    :param hidden_size: hidden size passed to ``DQSN``
    :param played_frames: playback stops at the first frame where the episode
        is done and the frame index is at least this value
    :param save_fig_num: the first ``save_fig_num`` frames are saved as
        ``<frame index>.png`` in ``fig_dir``
    :param fig_dir: directory for saved frames; required when
        ``save_fig_num > 0``
    :param figsize: value written to ``plt.rcParams['figure.figsize']``
    :param firing_rates_plot_type: ``'bar'`` or ``'heatmap'``
    :param heatmap_shape: shape the firing rates are reshaped to for the
        heatmap; only used when ``firing_rates_plot_type == 'heatmap'``
    :return: None
    :raises ValueError: if ``save_fig_num > 0`` but ``fig_dir`` is ``None``
    '''
    # Fail before loading the model instead of crashing inside os.path.join
    # on the first frame.
    if save_fig_num > 0 and fig_dir is None:
        raise ValueError('fig_dir must be given when save_fig_num > 0')

    T = 16

    plt.rcParams['figure.figsize'] = figsize
    plt.ion()

    env = gym.make(env_name).unwrapped

    device = torch.device("cuda" if use_cuda else "cpu")

    n_states = env.observation_space.shape[0]
    n_actions = env.action_space.n

    policy_net = DQSN(n_states, hidden_size, n_actions, T).to(device)
    policy_net.load_state_dict(torch.load(pt_path, map_location=device))

    # NOTE: the observation returned by reset() is not used; the first forward
    # pass is fed an all-zero state.
    env.reset()
    state = torch.zeros([1, n_states], dtype=torch.float, device=device)

    with torch.no_grad():
        functional.set_monitor(policy_net, True)
        over_score = 1e9
        for i in count():
            LIF_v = policy_net(state)  # shape=[1, 2]
            action = LIF_v.max(1)[1].view(1, 1).item()

            # matplotlib cannot read CUDA tensors, so copy the values to the
            # host once and plot plain NumPy / Python numbers from here on.
            voltages = LIF_v.detach().cpu().numpy().squeeze()
            v_min = float(voltages.min())
            v_max = float(voltages.max())
            v_left = float(voltages[0])
            v_right = float(voltages[1])

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (1, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (1, 0))

            plt.xticks(np.arange(2), ('Left', 'Right'))
            plt.ylabel('Voltage')
            plt.title('Voltage of LIF neurons at last time step')
            # Pad the y-range by half of the spread on both sides.
            delta_lim = (v_max - v_min) * 0.5
            plt.ylim(v_min - delta_lim, v_max + delta_lim)
            plt.yticks([])
            plt.text(0, v_left, str(round(v_left, 2)), ha='center')
            plt.text(1, v_right, str(round(v_right, 2)), ha='center')

            plt.bar(np.arange(2), voltages, color=['r', 'gray'] if action == 0 else ['gray', 'r'], width=0.5)
            if v_min - delta_lim < 0:
                plt.axhline(0, color='black', linewidth=0.1)

            # Per the original author's note: shape=[16, 1, 256] — TODO confirm
            IF_spikes = np.asarray(policy_net.fc[1].monitor['s'])
            # Average over the first axis to get one firing rate per neuron.
            firing_rates = IF_spikes.mean(axis=0).squeeze()

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 4), rowspan=2, colspan=5)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 1), rowspan=2, colspan=2)

            plt.title('Firing rates of IF neurons')

            if firing_rates_plot_type == 'bar':
                # Bar chart
                plt.xlabel('Neuron index')
                plt.ylabel('Firing rate')
                plt.xlim(0, firing_rates.size)
                plt.ylim(0, 1.01)
                plt.bar(np.arange(firing_rates.size), firing_rates, width=0.5)
            elif firing_rates_plot_type == 'heatmap':
                # Heatmap
                heatmap = plt.imshow(firing_rates.reshape(heatmap_shape), vmin=0, vmax=1, cmap='ocean')
                plt.gca().invert_yaxis()
                cbar = heatmap.figure.colorbar(heatmap)
                cbar.ax.set_ylabel('Magnitude', rotation=90, va='top')

            functional.reset_net(policy_net)
            # The subtitle describes the state the action was chosen for,
            # so it is built before stepping the environment.
            subtitle = f'Position={state[0][0].item(): .2f}, Velocity={state[0][1].item(): .2f}, Pole Angle={state[0][2].item(): .2f}, Pole Velocity At Tip={state[0][3].item(): .2f}, Score={i}'

            state, reward, done, _ = env.step(action)

            if done:
                # Remember the frame index at which the episode first ended.
                over_score = min(over_score, i)
                subtitle = f'Game over, Score={over_score}'

            plt.suptitle(subtitle)

            state = torch.from_numpy(state).float().to(device).unsqueeze(0)
            screen = env.render(mode='rgb_array').copy()
            screen[300, :, :] = 0  # draw a black line

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 0))

            plt.xticks([])
            plt.yticks([])
            plt.title('Game screen')
            plt.imshow(screen, interpolation='bicubic')
            plt.pause(0.001)

            if i < save_fig_num:
                plt.savefig(os.path.join(fig_dir, f'{i}.png'))

            if done and i >= played_frames:
                env.close()
                plt.close()
                break


'''
train(use_cuda=False, model_dir='./model/CartPole-v0/state', log_dir='./log', env_name='CartPole-v0', \
    hidden_size=256, num_episodes=500, seed=1)
'''
play(use_cuda=False, pt_path='./model/CartPole-v0/policy_net_256_max.pt', env_name='CartPole-v0',
     hidden_size=256, played_frames=300)
# Omitting pruning for all BN layers BN_list = ['static_conv.1', 'conv.2', 'conv.5', 'conv.9', 'conv.12', 'conv.15'] for name, param in net.named_parameters(): if any(BN_name in name for BN_name in BN_list): bn_params += [param] ttl_cnt += param.numel() else: weight_params += [param] w_cnt += param.numel() ttl_cnt += param.numel() ###### TEST MODE ###### if test: with torch.no_grad(): # Turn on monitor set_monitor(net, True) # Record total spike times by layer spike_times = dict() for name, module in net.named_modules(): if hasattr(module, 'monitor'): spike_times[name] = 0 test_sum = 0 correct_sum = 0 for img, label in test_data_loader: img = img.cuda(non_blocking=True) label = label.cuda(non_blocking=True)