Beispiel #1
0
def _accumulate_output_spikes(net, encoder, img, T):
    '''
    Present ``img`` to ``net`` for ``T`` steps and return the summed outputs.

    A fresh ``encoder(img)`` sample is drawn for every step. The network state
    is not reset here; the caller does that once it is done with the result.

    :param net: network to run
    :param encoder: callable applied to ``img`` before each forward pass
    :param img: input tensor, already moved to the device of ``net``
    :param T: number of simulation steps, must be at least 1
    :return: element-wise sum of the ``T`` network outputs
    '''
    out_spikes_counter = net(encoder(img).float())
    for _ in range(T - 1):
        out_spikes_counter += net(encoder(img).float())
    return out_spikes_counter


def _evaluate(net, encoder, data_loader, device, T):
    '''
    Return the classification accuracy of ``net`` over ``data_loader``.

    The predicted class of a sample is the index of the output with the
    largest accumulated value over ``T`` steps. The network is reset after
    every batch. The caller is expected to have put ``net`` in eval mode.

    :param net: network to evaluate
    :param encoder: callable applied to each image batch before each step
    :param data_loader: iterable of ``(img, label)`` batches
    :param device: device the batches are moved to
    :param T: number of simulation steps per batch
    :return: fraction of correctly classified samples
    '''
    correct_sum = 0
    test_sum = 0
    with torch.no_grad():
        for img, label in data_loader:
            img = img.to(device)
            label = label.to(device)
            out_spikes_counter = _accumulate_output_spikes(
                net, encoder, img, T)
            correct_sum += (out_spikes_counter.max(1)[1] == label
                            ).float().sum().item()
            test_sum += label.numel()
            functional.reset_net(net)
    return correct_sum / test_sum


def main():
    '''
    Interactively train and evaluate a Flatten-Linear-LIF network on MNIST.

    .. _lif_fc_mnist.main-cn:

    .. _lif_fc_mnist.main-en:

    All hyper-parameters (device, dataset directory, batch size, learning
    rate, simulation steps ``T``, LIF time constant ``tau``, number of epochs
    and TensorBoard log directory) are read from standard input.

    The network is ``Flatten -> Linear(28 * 28, 10, bias=False) -> LIFNode``.
    Each batch is presented for ``T`` steps through a Poisson encoder, the
    output is averaged over the ``T`` steps, and the MSE between that average
    and the one-hot label is minimised with Adam. After every epoch the
    accuracy on the test set is computed and the best value so far is printed.

    Side effects:

    * scalars ``train_accuracy`` (per batch) and ``test_accuracy`` (per epoch)
      are written to the TensorBoard log directory;
    * ``v_t_array.npy`` and ``s_t_array.npy`` (monitor data of the output
      layer for the first test image), ``train_accs.npy`` and
      ``test_accs.npy`` are saved in the current working directory.

    :raises ValueError: if the entered number of simulation steps is below 1
    :return: None
    '''
    device = input(
        '输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
    dataset_dir = input(
        '输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": '
    )
    batch_size = int(
        input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
    learning_rate = float(
        input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
    T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
    tau = float(
        input(
            '输入LIF神经元的时间常数tau,例如“100.0”\n input membrane time constant, tau, for LIF neurons, e.g., "100.0": '
        ))
    train_epoch = int(
        input(
            '输入训练轮数,即遍历训练集的次数,例如“100”\n input training epochs, e.g., "100": '))
    log_dir = input(
        '输入保存tensorboard日志文件的位置,例如“./”\n input root directory for saving tensorboard logs, e.g., "./": '
    )

    # Fail before any download/training instead of crashing inside the first
    # batch when the simulation loop would not run at all.
    if T < 1:
        raise ValueError(f'simulating steps T must be at least 1, got {T}')

    writer = SummaryWriter(log_dir)
    try:
        # Data loaders
        train_dataset = torchvision.datasets.MNIST(
            root=dataset_dir,
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True)
        test_dataset = torchvision.datasets.MNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True)

        train_data_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=test_dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False)

        # Network, optimizer (Adam) and input encoder (Poisson)
        net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10, bias=False),
                            neuron.LIFNode(tau=tau))
        net = net.to(device)
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
        encoder = encoding.PoissonEncoder()

        train_times = 0
        max_test_accuracy = 0

        test_accs = []
        train_accs = []

        for epoch in range(train_epoch):
            net.train()
            for img, label in tqdm(train_data_loader):
                img = img.to(device)
                label = label.to(device)
                label_one_hot = F.one_hot(label, 10).float()

                optimizer.zero_grad()

                # Sum of the 10 output neurons' spikes over the T steps,
                # shape [batch_size, 10]; dividing by T gives firing rates.
                out_spikes_counter = _accumulate_output_spikes(
                    net, encoder, img, T)
                out_spikes_counter_frequency = out_spikes_counter / T

                # MSE between firing rates and the one-hot label: pushes the
                # rate of the neuron of the true class towards 1 and the
                # others towards 0.
                loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
                loss.backward()
                optimizer.step()
                # Spiking neurons are stateful, so the network must be reset
                # after every optimisation step.
                functional.reset_net(net)

                # The predicted class is the output neuron with the highest
                # firing rate.
                accuracy = (out_spikes_counter_frequency.max(1)[1] == label
                            ).float().mean().item()

                writer.add_scalar('train_accuracy', accuracy, train_times)
                train_accs.append(accuracy)

                train_times += 1

            # Evaluate on the test set once per pass over the training set.
            net.eval()
            test_accuracy = _evaluate(net, encoder, test_data_loader, device,
                                      T)
            writer.add_scalar('test_accuracy', test_accuracy, epoch)
            test_accs.append(test_accuracy)
            max_test_accuracy = max(max_test_accuracy, test_accuracy)
            print(
                f'Epoch {epoch}: device={device}, dataset_dir={dataset_dir}, batch_size={batch_size}, learning_rate={learning_rate}, T={T}, log_dir={log_dir}, max_test_accuracy={max_test_accuracy}, train_times={train_times}'
            )
    finally:
        # Flush pending events and release the log file even when training
        # is interrupted or fails.
        writer.close()

    # Save data for plotting: run the first test image with the monitor on.
    net.eval()
    functional.set_monitor(net, True)
    with torch.no_grad():
        img, _ = test_dataset[0]
        img = img.to(device)
        out_spikes_counter = _accumulate_output_spikes(net, encoder, img, T)
        out_spikes_counter_frequency = (out_spikes_counter / T).cpu().numpy()
        print(f'Firing rate: {out_spikes_counter_frequency}')
        output_layer = net[-1]  # output layer
        # v_t_array[i][j] is the voltage of neuron i at step j
        v_t_array = np.asarray(output_layer.monitor['v']).squeeze().T
        np.save("v_t_array.npy", v_t_array)
        # s_t_array[i][j] is the spike (0 or 1) of neuron i at step j
        s_t_array = np.asarray(output_layer.monitor['s']).squeeze().T
        np.save("s_t_array.npy", s_t_array)

    train_accs = np.array(train_accs)
    np.save('train_accs.npy', train_accs)
    test_accs = np.array(test_accs)
    np.save('test_accs.npy', test_accs)
def play(use_cuda,
         pt_path,
         env_name,
         hidden_size,
         played_frames=60,
         save_fig_num=0,
         fig_dir=None,
         figsize=(12, 6),
         firing_rates_plot_type='bar',
         heatmap_shape=None):
    '''
    Run a trained ``DQSN`` policy in a gym environment and visualise it live.

    Every frame shows three panels: the rendered game screen, the two output
    values of the policy network (the chosen action highlighted in red), and
    the firing rates of ``policy_net.fc[1]`` taken from its monitor. The
    labels, the two-bar plot and the four-component state subtitle are
    hard-coded for a CartPole-like task (2 actions, 4 state variables).

    The loop does not stop at the first ``done``: it keeps stepping the
    unwrapped environment and ends at the first frame that is ``done`` with
    frame index ``i >= played_frames``.

    :param use_cuda: run the network on ``"cuda"`` if true, else on ``"cpu"``
    :param pt_path: path of the state dict loaded into the policy network
    :param env_name: environment id passed to ``gym.make``
    :param hidden_size: hidden size passed to ``DQSN``
    :param played_frames: minimum frame index before a ``done`` ends the run
    :param save_fig_num: the first ``save_fig_num`` frames are saved as
        ``<fig_dir>/<i>.png``
    :param fig_dir: directory for the saved frames; required when
        ``save_fig_num > 0``
    :param figsize: value assigned to ``plt.rcParams['figure.figsize']``
    :param firing_rates_plot_type: ``'bar'`` or ``'heatmap'``
    :param heatmap_shape: shape the firing rates are reshaped to for the
        heatmap
    :raises ValueError: if frames should be saved but ``fig_dir`` is ``None``
    :return: None
    '''
    # Fail before opening the environment instead of crashing in
    # os.path.join on the first saved frame.
    if save_fig_num > 0 and fig_dir is None:
        raise ValueError('fig_dir must be given when save_fig_num > 0')

    T = 16

    plt.rcParams['figure.figsize'] = figsize
    plt.ion()

    env = gym.make(env_name).unwrapped

    device = torch.device("cuda" if use_cuda else "cpu")

    n_states = env.observation_space.shape[0]
    n_actions = env.action_space.n

    policy_net = DQSN(n_states, hidden_size, n_actions, T).to(device)
    policy_net.load_state_dict(torch.load(pt_path, map_location=device))

    # NOTE(review): the observation returned by reset() is discarded and the
    # first action is computed from an all-zero state, as in the original.
    env.reset()
    state = torch.zeros([1, n_states], dtype=torch.float, device=device)

    with torch.no_grad():
        functional.set_monitor(policy_net, True)
        over_score = 1e9

        for i in count():
            LIF_v = policy_net(state)  # shape=[1, 2] per the original note
            action = LIF_v.max(1)[1].view(1, 1).item()

            # matplotlib cannot consume tensors that live on a CUDA device,
            # so plot from a host-side NumPy copy.
            voltages = LIF_v.squeeze().cpu().numpy()
            v_min = float(voltages.min())
            v_max = float(voltages.max())

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (1, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (1, 0))

            plt.xticks(np.arange(2), ('Left', 'Right'))
            plt.ylabel('Voltage')
            plt.title('Voltage of LIF neurons at last time step')
            delta_lim = (v_max - v_min) * 0.5
            plt.ylim(v_min - delta_lim, v_max + delta_lim)
            plt.yticks([])
            plt.text(0,
                     float(voltages[0]),
                     str(round(float(voltages[0]), 2)),
                     ha='center')
            plt.text(1,
                     float(voltages[1]),
                     str(round(float(voltages[1]), 2)),
                     ha='center')

            plt.bar(np.arange(2),
                    voltages,
                    color=['r', 'gray'] if action == 0 else ['gray', 'r'],
                    width=0.5)

            # Draw the zero line only when it lies inside the visible range.
            if v_min - delta_lim < 0:
                plt.axhline(0, color='black', linewidth=0.1)

            # Original note: shape=[16, 1, 256] (T=16, hidden_size=256)
            IF_spikes = np.asarray(policy_net.fc[1].monitor['s'])
            firing_rates = IF_spikes.mean(axis=0).squeeze()

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 4), rowspan=2, colspan=5)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 1), rowspan=2, colspan=2)

            plt.title('Firing rates of IF neurons')

            if firing_rates_plot_type == 'bar':
                # Bar chart of the firing rates
                plt.xlabel('Neuron index')
                plt.ylabel('Firing rate')
                plt.xlim(0, firing_rates.size)
                plt.ylim(0, 1.01)
                plt.bar(np.arange(firing_rates.size), firing_rates, width=0.5)

            elif firing_rates_plot_type == 'heatmap':
                # Heatmap of the firing rates
                heatmap = plt.imshow(firing_rates.reshape(heatmap_shape),
                                     vmin=0,
                                     vmax=1,
                                     cmap='ocean')
                plt.gca().invert_yaxis()
                cbar = heatmap.figure.colorbar(heatmap)
                cbar.ax.set_ylabel('Magnitude', rotation=90, va='top')

            functional.reset_net(policy_net)
            # The subtitle describes the state the action was chosen from,
            # so it is built before the environment is stepped.
            subtitle = f'Position={state[0][0].item(): .2f}, Velocity={state[0][1].item(): .2f}, Pole Angle={state[0][2].item(): .2f}, Pole Velocity At Tip={state[0][3].item(): .2f}, Score={i}'

            state, reward, done, _ = env.step(action)

            if done:
                # Keep the frame index of the first ``done`` as final score.
                over_score = min(over_score, i)
                subtitle = f'Game over, Score={over_score}'
            plt.suptitle(subtitle)

            state = torch.from_numpy(state).float().to(device).unsqueeze(0)
            screen = env.render(mode='rgb_array').copy()
            screen[300, :, :] = 0  # draw a black horizontal line

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 0))

            plt.xticks([])
            plt.yticks([])
            plt.title('Game screen')
            plt.imshow(screen, interpolation='bicubic')
            plt.pause(0.001)

            if i < save_fig_num:
                plt.savefig(os.path.join(fig_dir, f'{i}.png'))

            if done and i >= played_frames:
                env.close()
                plt.close()
                break


# Script entry point. These statements must live outside ``play``: placed in
# its body they would make ``play`` call itself again after every run.
if __name__ == '__main__':
    # Training call kept for reference (disabled):
    # train(use_cuda=False, model_dir='./model/CartPole-v0/state', log_dir='./log', env_name='CartPole-v0',
    #       hidden_size=256, num_episodes=500, seed=1)

    play(use_cuda=False, pt_path='./model/CartPole-v0/policy_net_256_max.pt', env_name='CartPole-v0',
         hidden_size=256, played_frames=300)
Beispiel #3
0
    # Omitting pruning for all BN layers
    BN_list = ['static_conv.1', 'conv.2', 'conv.5', 'conv.9', 'conv.12', 'conv.15']
    for name, param in net.named_parameters():
        if any(BN_name in name for BN_name in BN_list):
            bn_params += [param]
            ttl_cnt += param.numel()
        else:
            weight_params += [param]
            w_cnt += param.numel()
            ttl_cnt += param.numel()

    ###### TEST MODE ######
    if test:
        with torch.no_grad():
            # Turn on monitor
            set_monitor(net, True)

            # Record total spike times by layer
            spike_times = dict()

            for name, module in net.named_modules():
                if hasattr(module, 'monitor'):
                    spike_times[name] = 0

            test_sum = 0
            correct_sum = 0

            for img, label in test_data_loader:
                img = img.cuda(non_blocking=True)
                label = label.cuda(non_blocking=True)