def simulate_snn(snn, device, data_loader, T, possion=False):
    '''
    .. _simulate_snn-en:

    Evaluate the classification accuracy of an SNN.

    For every batch the network is run for ``T`` time steps, the outputs of all
    steps are summed, and the index of the largest summed output (along
    dimension 1) is taken as the predicted class. The network is reset with
    ``snn.reset_()`` after each batch.

    :param snn: SNN model; called once per time step and must provide ``reset_()``
    :param device: running device
    :param data_loader: testing data loader yielding ``(img, label)`` batches
    :param T: simulating steps, must be at least 1
    :param possion: when ``True``, use Poisson encoder; otherwise, use constant
        input over T steps (the parameter name keeps its historical spelling)
    :return: SNN simulating accuracy, i.e. ``correct / total`` over all samples
    :raises ValueError: if ``T`` is smaller than 1
    :raises ZeroDivisionError: if ``data_loader`` yields no samples
    '''
    if T < 1:
        raise ValueError('T must be at least 1, got %r' % (T,))

    encoder = encoding.PoissonEncoder() if possion else None
    correct = 0.0
    total = 0.0
    with torch.no_grad():
        snn.eval()
        for batch, (img, label) in enumerate(data_loader):
            img = img.to(device)
            label = label.to(device)
            out_spikes_counter = None
            for _ in range(T):
                # Encode the ORIGINAL image at every step. Re-assigning ``img``
                # here would feed the previous step's encoding back into the
                # encoder instead of the original image.
                step_input = encoder(img).float() if possion else img
                out = snn(step_input)
                if out_spikes_counter is None:
                    # Clone so the in-place accumulation below never writes
                    # into a tensor owned by the model or the data loader.
                    out_spikes_counter = out.clone()
                else:
                    out_spikes_counter += out

            predicted = out_spikes_counter.max(1)[1]
            correct += (predicted == label).float().sum().item()
            total += label.numel()
            # Spiking neurons are stateful; clear the state before the next batch.
            snn.reset_()
            print('[SNN Simulating... %.2f%%] Acc:%.3f' %
                  (100 * (batch + 1) / len(data_loader), correct / total))
        acc = correct / total
        print('SNN Simulating Accuracy:%.3f' % (acc))
    return acc
def __init__(self, snn, device, name='', **kargs):
    """Set up the simulator: log directory, input encoder and per-device model copies.

    :param snn: SNN model; it is switched to eval mode and deep-copied onto
        every target device (the copies are stored in the shared dictionary
        under ``'distri_model'``)
    :param device: a single device, or a list/tuple/set of devices. A
        collection with exactly one element is treated as that single device;
        any other collection enables parallel inference (``self.pi``)
    :param name: optional suffix appended to the auto-generated log directory name
    :param kargs: extra options, all of which are merged into ``self.config``.
        ``log_dir`` overrides the auto-generated log directory;
        ``encoder='poisson'`` selects the Poisson encoder, any other value
        (default ``'constant'``) passes the input through unchanged
    """
    snn.eval()

    if 'log_dir' in kargs:
        self.log_dir = kargs['log_dir']
    else:
        from datetime import datetime
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        suffix = '_' + name if name else ''
        self.log_dir = self.__class__.__name__ + '-' + current_time + suffix
    print('simulator log_dir:', self.log_dir)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(self.log_dir, exist_ok=True)

    if kargs.get('encoder', 'constant') == 'poisson':
        self.encoder = encoding.PoissonEncoder()
    else:
        self.encoder = lambda x: x

    self.pi = False
    if isinstance(device, (list, set, tuple)):
        if len(device) == 1:
            # next(iter(...)) also works for a set, which cannot be indexed.
            device = next(iter(device))
        else:
            self.pi = True  # parallel inference
    if self.pi:
        print('simulator is working on the parallel mode, device(s):', device)
    else:
        print('simulator is working on the normal mode, device:', device)
    self.device = device

    # Module-level shared state and locks (only read here, so no ``global``
    # declaration is required).
    self.mutex_shared = mutex_shared
    self.mutex_schedule = mutex_schedule
    self.global_shared = global_shared

    if self.pi:
        self.global_shared['device_used'] = defaultdict(int)
        self.global_shared['device_stat'] = defaultdict(int)
        self.global_shared['distri_model'] = {}
        self.global_shared['batch'] = 0
        self.global_shared['batch_sum'] = 0
        self.global_shared['T'] = None
        for dev in self.device:
            self.global_shared['distri_model'][dev] = copy.deepcopy(snn).to(dev)
    else:
        self.global_shared['distri_model'] = {
            self.device: copy.deepcopy(snn).to(self.device)
        }

    # Explicit keyword options take precedence over the defaults recorded here.
    self.config = {
        'device': self.device,
        'name': name,
        'log_dir': self.log_dir,
        **kargs,
    }
def _accumulate_output_spikes(net, encoder, img, T):
    """Run ``net`` for ``T`` steps on freshly encoded ``img`` and return the summed outputs."""
    for t in range(T):
        if t == 0:
            out_spikes_counter = net(encoder(img).float())
        else:
            out_spikes_counter += net(encoder(img).float())
    return out_spikes_counter


def main():
    '''
    .. _lif_fc_mnist.main-en:

    :return: None

    Train a single-layer spiking network (Flatten -> Linear(784, 10) -> LIF)
    to classify MNIST. All hyper-parameters are read interactively from
    standard input. The function initializes the network, trains it, and
    evaluates accuracy on the test dataset after every epoch. Train/test
    accuracies are logged to tensorboard, and at the end the monitored
    voltages/spikes of the output layer for the first test sample plus the
    accuracy histories are saved as ``.npy`` files in the working directory.
    '''
    device = input(
        '输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
    dataset_dir = input(
        '输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": '
    )
    batch_size = int(
        input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
    learning_rate = float(
        input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
    T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
    tau = float(
        input(
            '输入LIF神经元的时间常数tau,例如“100.0”\n input membrane time constant, tau, for LIF neurons, e.g., "100.0": '
        ))
    train_epoch = int(
        input(
            '输入训练轮数,即遍历训练集的次数,例如“100”\n input training epochs, e.g., "100": '))
    log_dir = input(
        '输入保存tensorboard日志文件的位置,例如“./”\n input root directory for saving tensorboard logs, e.g., "./": '
    )

    writer = SummaryWriter(log_dir)
    try:
        # Initialize the data loaders.
        to_tensor = torchvision.transforms.ToTensor()
        train_dataset = torchvision.datasets.MNIST(
            root=dataset_dir, train=True, transform=to_tensor, download=True)
        test_dataset = torchvision.datasets.MNIST(
            root=dataset_dir, train=False, transform=to_tensor, download=True)
        train_data_loader = torch.utils.data.DataLoader(
            dataset=train_dataset, batch_size=batch_size, shuffle=True,
            drop_last=True)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=test_dataset, batch_size=batch_size, shuffle=False,
            drop_last=False)

        # Define and initialize the network.
        net = nn.Sequential(nn.Flatten(),
                            nn.Linear(28 * 28, 10, bias=False),
                            neuron.LIFNode(tau=tau))
        net = net.to(device)
        # Adam optimizer.
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
        # Poisson encoder: the image is re-encoded at every time step.
        encoder = encoding.PoissonEncoder()

        train_times = 0
        max_test_accuracy = 0
        test_accs = []
        train_accs = []

        for epoch in range(train_epoch):
            net.train()
            for img, label in tqdm(train_data_loader):
                img = img.to(device)
                label = label.to(device)
                label_one_hot = F.one_hot(label, 10).float()

                optimizer.zero_grad()

                # Simulate T steps; out_spikes_counter is a tensor of
                # shape=[batch_size, 10] holding how many times each of the
                # 10 output neurons fired during the whole simulation.
                out_spikes_counter = _accumulate_output_spikes(
                    net, encoder, img, T)

                # Dividing by T gives the firing rate of each output neuron.
                out_spikes_counter_frequency = out_spikes_counter / T

                # The loss is the MSE between the output firing rates and the
                # one-hot label: for an input of class i it pushes the rate of
                # output neuron i towards 1 and all other rates towards 0.
                loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
                loss.backward()
                optimizer.step()
                # Spiking neurons keep state ("memory"), so the network must
                # be reset after every optimization step.
                functional.reset_net(net)

                # The index of the output neuron with the highest firing rate
                # is taken as the predicted class.
                accuracy = (out_spikes_counter_frequency.max(1)[1] ==
                            label).float().mean().item()

                writer.add_scalar('train_accuracy', accuracy, train_times)
                train_accs.append(accuracy)

                train_times += 1

            net.eval()
            with torch.no_grad():
                # Evaluate on the test set once per pass over the training set.
                test_sum = 0
                correct_sum = 0
                for img, label in test_data_loader:
                    img = img.to(device)
                    label = label.to(device)
                    out_spikes_counter = _accumulate_output_spikes(
                        net, encoder, img, T)

                    correct_sum += (out_spikes_counter.max(1)[1] ==
                                    label).float().sum().item()
                    test_sum += label.numel()
                    functional.reset_net(net)
                test_accuracy = correct_sum / test_sum
                writer.add_scalar('test_accuracy', test_accuracy, epoch)
                test_accs.append(test_accuracy)
                max_test_accuracy = max(max_test_accuracy, test_accuracy)
            print(
                f'Epoch {epoch}: device={device}, dataset_dir={dataset_dir}, batch_size={batch_size}, learning_rate={learning_rate}, T={T}, log_dir={log_dir}, max_test_accuracy={max_test_accuracy}, train_times={train_times}'
            )
    finally:
        # Flush and release the tensorboard event file even if training is
        # interrupted; otherwise buffered scalars can be lost.
        writer.close()

    # Save the data used for plotting.
    net.eval()
    functional.set_monitor(net, True)
    with torch.no_grad():
        img, label = test_dataset[0]
        img = img.to(device)
        out_spikes_counter = _accumulate_output_spikes(net, encoder, img, T)
        out_spikes_counter_frequency = (out_spikes_counter / T).cpu().numpy()
        print(f'Firing rate: {out_spikes_counter_frequency}')
        output_layer = net[-1]  # output layer
        # v_t_array[i][j] is the membrane voltage of neuron i at time step j.
        v_t_array = np.asarray(output_layer.monitor['v']).squeeze().T
        np.save("v_t_array.npy", v_t_array)
        # s_t_array[i][j] is the spike (0 or 1) emitted by neuron i at time step j.
        s_t_array = np.asarray(output_layer.monitor['s']).squeeze().T
        np.save("s_t_array.npy", s_t_array)

    train_accs = np.array(train_accs)
    np.save('train_accs.npy', train_accs)
    test_accs = np.array(test_accs)
    np.save('test_accs.npy', test_accs)
def simulate_snn(snn, device, data_loader, T, poisson=False, online_draw=False,
                 fig_name='default', ann_baseline=0, save_acc_list=False,
                 log_dir=None):  # TODO ugly
    '''
    .. _simulate_snn-en:

    Evaluate the classification accuracy of an SNN, optionally plotting and
    saving the accuracy reached after each time step.

    For every batch the network is run for ``T`` time steps, the outputs are
    summed, and the index of the largest summed output (along dimension 1) is
    the predicted class. The network is reset before the simulation and after
    every batch.

    :param snn: SNN model; if it returns a tuple or list, the first element is
        used as the output
    :param device: running device
    :param data_loader: testing data loader; ``len(data_loader.dataset)`` is
        used for progress reporting
    :param T: simulating steps, must be at least 1
    :param poisson: when ``True``, use poisson encoder; otherwise, use constant
        input over T steps
    :param online_draw: when ``True``, redraw the accuracy-vs-T curve with
        matplotlib after every batch
    :param fig_name: figure title prefix and base name of the saved ``.pdf``
    :param ann_baseline: when non-zero, drawn as a horizontal reference line
        labelled ``ANN`` on the accuracy (%) axis
    :param save_acc_list: when ``True``, save the per-step accuracy list (in
        percent) with ``np.save`` into ``log_dir`` after every batch
    :param log_dir: directory (``str``) for the saved figure and accuracy
        list; figures are only saved when this is a string
    :return: SNN simulating accuracy, i.e. ``correct / total`` over all samples
    :raises ValueError: if ``T`` is smaller than 1, or if ``save_acc_list`` is
        set while ``log_dir`` is not a string
    :raises ZeroDivisionError: if ``data_loader`` yields no samples
    '''
    if T < 1:
        raise ValueError('T must be at least 1, got %r' % (T,))
    can_save = isinstance(log_dir, str)
    if save_acc_list and not can_save:
        # Fail before simulating anything instead of after the first batch.
        raise ValueError('log_dir must be a str when save_acc_list is True')

    functional.reset_net(snn)
    encoder = encoding.PoissonEncoder() if poisson else None
    fig_path = log_dir + '/' + fig_name + ".pdf" if can_save else None

    # correct_t[t]: number of samples classified correctly after step t,
    # accumulated over all batches processed so far.
    correct_t = {}
    if online_draw:
        plt.ion()
    with torch.no_grad():
        snn.eval()
        correct = 0.0
        total = 0.0
        for img, label in data_loader:
            img = img.to(device)
            label = label.to(device)
            out_spikes_counter = None
            for t in tqdm.tqdm(range(T)):
                encoded = encoder(img).float() if poisson else img
                out = snn(encoded)
                if isinstance(out, (tuple, list)):
                    out = out[0]
                if out_spikes_counter is None:
                    # Clone so the in-place accumulation never writes into a
                    # tensor owned by the model or the data loader.
                    out_spikes_counter = out.clone()
                else:
                    out_spikes_counter += out

                hits = (out_spikes_counter.max(1)[1] ==
                        label).float().sum().item()
                correct_t[t] = correct_t.get(t, 0.0) + hits
            # ``hits`` now holds the count after the final step.
            correct += hits
            total += label.numel()
            functional.reset_net(snn)

            if online_draw:
                plt.cla()
                x = np.array(list(correct_t.keys())).astype(np.float32) + 1
                y = np.array(list(correct_t.values())).astype(
                    np.float32) / total * 100
                plt.plot(x, y, label='SNN', c='b')
                if ann_baseline != 0:
                    plt.plot(x, np.ones_like(x) * ann_baseline,
                             label='ANN', c='g', linestyle=':')
                    plt.text(0, ann_baseline + 1, "%.3f%%" % (ann_baseline),
                             fontdict={'size': '8', 'color': 'g'})
                plt.title("%s SNN Simulation \n[test samples:%.1f%%]" %
                          (fig_name, 100 * total / len(data_loader.dataset)))
                plt.xlabel("T")
                plt.ylabel("Accuracy(%)")
                plt.legend()
                argmax = np.argmax(y)
                # Shift the label left when the maximum is near the right edge.
                disp_bias = 0.3 * float(T) if x[argmax] / T > 0.7 else 0
                plt.text(x[argmax] - 0.8 - disp_bias, y[argmax] + 0.8,
                         "MAX:%.3f%% T=%d" % (y[argmax], x[argmax]),
                         fontdict={'size': '12', 'color': 'r'})
                plt.scatter([x[argmax]], [y[argmax]], c='r')
                plt.pause(0.01)
                if can_save:
                    plt.savefig(fig_path)

            print('[SNN Simulating... %.2f%%] Acc:%.3f' %
                  (100 * total / len(data_loader.dataset), correct / total))
            if save_acc_list:
                acc_list = np.array(list(correct_t.values())).astype(
                    np.float32) / total * 100
                np.save(
                    log_dir + '/snn_acc-list' +
                    ('-poisson' if poisson else '-constant'), acc_list)
        acc = correct / total
        print('SNN Simulating Accuracy:%.3f' % (acc))

    if online_draw:
        # Same guard as inside the loop: with log_dir=None there is nowhere
        # to save, but the figure must still be closed and acc returned.
        if can_save:
            plt.savefig(fig_path)
        plt.ioff()
        plt.close()
    return acc