示例#1
0
def ppo_update(ppo_epochs,
               mini_batch_size,
               states,
               actions,
               log_probs,
               returns,
               advantages,
               clip_param=0.2):
    """Optimise the actor-critic ``model`` with the PPO clipped objective.

    Makes ``ppo_epochs`` passes over the rollout. Each pass walks the
    mini-batches yielded by ``ppo_iter`` and takes one ``optimizer`` step per
    mini-batch on ``0.5 * value_loss + policy_loss - 0.001 * entropy``.

    Relies on module-level ``model``, ``optimizer``, ``ppo_iter`` and
    ``functional`` defined elsewhere in the script.

    :param clip_param: half-width of the probability-ratio clipping interval
        ``[1 - clip_param, 1 + clip_param]``.
    """
    low, high = 1.0 - clip_param, 1.0 + clip_param
    for _epoch in range(ppo_epochs):
        batches = ppo_iter(mini_batch_size, states, actions, log_probs,
                           returns, advantages)
        for state, action, old_log_probs, return_, advantage in batches:
            dist, value = model(state)
            # The network is stateful (SpikingJelly); reset it right after the
            # forward pass so the next mini-batch starts from a clean state.
            functional.reset_net(model)

            ratio = torch.exp(dist.log_prob(action) - old_log_probs)
            unclipped = ratio * advantage
            clipped = torch.clamp(ratio, low, high) * advantage
            policy_loss = -torch.min(unclipped, clipped).mean()
            value_loss = (return_ - value).pow(2).mean()
            entropy_bonus = dist.entropy().mean()

            total = 0.5 * value_loss + policy_loss - 0.001 * entropy_bonus

            optimizer.zero_grad()
            total.backward()
            optimizer.step()
示例#2
0
    def optimize_model():
        """Perform one DQN optimisation step on a replay-memory sample.

        Does nothing until ``memory`` holds at least ``BATCH_SIZE``
        transitions. Otherwise it samples a batch and regresses
        ``Q_policy(s, a)`` toward ``r + GAMMA * max_a' Q_target(s', a')`` with
        a smooth-L1 loss. Terminal transitions (``next_state is None``)
        bootstrap from 0. Gradients are clamped element-wise to ``[-1, 1]``
        before the optimizer step. Both spiking networks are reset after use.

        Relies on enclosing-scope names (``memory``, ``BATCH_SIZE``,
        ``Transition``, ``policy_net``, ``target_net``, ``optimizer``,
        ``GAMMA``, ``device``, ``functional``, ``F``) defined elsewhere.
        """
        if len(memory) < BATCH_SIZE:
            return
        transitions = memory.sample(BATCH_SIZE)
        # Transpose a list of Transition tuples into one Transition of tuples.
        batch = Transition(*zip(*transitions))

        non_final_mask = torch.tensor(
            [s is not None for s in batch.next_state],
            device=device, dtype=torch.bool)
        non_final_next_states = [s for s in batch.next_state if s is not None]
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        state_action_values = policy_net(state_batch).gather(1, action_batch)

        # Terminal transitions keep a bootstrap value of 0.
        next_state_values = torch.zeros(BATCH_SIZE, device=device)
        if non_final_next_states:
            # torch.cat([]) raises, so only query the target network when the
            # sample contains at least one non-terminal transition.
            next_state_values[non_final_mask] = target_net(
                torch.cat(non_final_next_states)).max(1)[0].detach()
        functional.reset_net(target_net)
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch

        loss = F.smooth_l1_loss(state_action_values,
                                expected_state_action_values.unsqueeze(1))

        optimizer.zero_grad()
        loss.backward()
        for param in policy_net.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        optimizer.step()
        functional.reset_net(policy_net)
示例#3
0
def test(net, mode, test_loader, device, loss_f, sparse_reg=0):

    net.eval()
    test_loss = 0
    correct_pred = 0

    with torch.no_grad():

        for x, label in test_loader:

            x, label = x.to(device), label.to(device)

            if mode == "ann": y = net(x)
            elif mode == "snn": y, spiking_actv = net(x)

            if loss_f == "mse":
                label = F.one_hot(label, 10).float()
                test_loss += F.mse_loss(y, label)
                correct_pred += (y.argmax(dim=1) == label.argmax(
                    dim=1)).sum().item()
            if loss_f == "ce":
                test_loss += F.cross_entropy(y, label).item()
                pred = y.argmax(dim=1)
                correct_pred += (pred == label).sum().item()

            if mode == "snn":
                functional.reset_net(net)

    test_acc = 100. * correct_pred / len(test_loader.dataset)
    test_loss /= len(test_loader)

    print("===> Test Accuracy : {:.2f}%, Test Average loss: {:.8f}".format(
        test_acc, test_loss))
    return test_loss, test_acc
示例#4
0
def train_full(net, mode, train_loader, optimizer, device, epoch):
    """Run one training epoch with an MSE loss against 10-class one-hot targets.

    Progress is printed every ``log_interval`` batches (a module-level value
    defined elsewhere). In ``"snn"`` mode the network state is reset after
    every batch.

    :return: ``(train_loss, train_acc)``: mean per-batch loss and accuracy
        in percent over ``train_loader.dataset``.
    """
    net.train()
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)
    running_loss = 0
    n_correct = 0

    for step, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()

        outputs = net(inputs)
        one_hot = F.one_hot(targets, 10).float()
        batch_loss = F.mse_loss(outputs, one_hot)
        hits = outputs.argmax(dim=1) == one_hot.argmax(dim=1)
        n_correct += hits.sum().item()

        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()

        if step % log_interval == 0:
            seen = step * len(inputs)
            progress = 100. * step / n_batches
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch, seen, n_samples, progress, batch_loss.item()))

        if mode == "snn":
            functional.reset_net(net)

    train_acc = 100. * n_correct / n_samples
    train_loss = running_loss / n_batches

    print("\n===> Train Epoch Accuracy : {:.2f}%, , Train Average loss: {:.8f}".format(train_acc, train_loss))
    return train_loss, train_acc
示例#5
0
    def get_values(self, data, targets, device, T, func_dict, **kargs):
        """Simulate the SNN replica assigned to ``device`` and record metrics.

        The network ``self.global_shared['distri_model'][device]`` is reset
        and then run for ``T`` steps on ``self.encoder(data)``. At every step
        each callable in ``func_dict`` is called with ``data``, ``targets``,
        that step's output (``out_spike``), the running sum of outputs
        (``out_spike_cnt``), ``device`` and ``**kargs``. Its per-step results
        are converted to a float32 numpy array and appended to
        ``self.global_shared[value_name]``.

        ``self._pre_batch_sim`` and ``self._after_batch_sim`` are called
        before and after the simulation. When ``self.pi`` is true, shared
        state is guarded by the module-level ``mutex_shared`` and
        ``mutex_schedule`` locks. ``self.global_shared['device_used'][device]``
        is 1 while the simulation runs and is set back to 0 afterwards, even
        if the simulation raises.
        """
        if self.pi:
            if mutex_shared.acquire():
                # try/finally so an exception cannot leave the lock held.
                try:
                    self._pre_batch_sim(**kargs)
                finally:
                    mutex_shared.release()
        else:
            self._pre_batch_sim(**kargs)

        data = data.to(device)
        targets = targets.to(device)
        values_list = defaultdict(list)

        if self.pi:
            if mutex_schedule.acquire():
                try:
                    self.global_shared['device_used'][device] = 1
                finally:
                    mutex_schedule.release()

        try:
            snn = self.global_shared['distri_model'][device]
            functional.reset_net(snn)
            with torch.no_grad():
                for t in range(T):
                    enc = self.encoder(data).float()
                    out = snn(enc)
                    # Out-of-place accumulation: tensors already handed to the
                    # callbacks at earlier steps are never mutated afterwards.
                    counter = out if t == 0 else counter + out
                    for value_name, func in func_dict.items():
                        value = func(data=data,
                                     targets=targets,
                                     out_spike=out,
                                     out_spike_cnt=counter,
                                     device=device,
                                     **kargs)
                        values_list[value_name].append(value)

            for value_name in func_dict:
                values_list[value_name] = np.array(
                    values_list[value_name]).astype(np.float32)

            if self.pi:
                if mutex_shared.acquire():
                    try:
                        for value_name in func_dict:
                            self.global_shared[value_name].append(
                                values_list[value_name])
                        self._after_batch_sim(**kargs)
                    finally:
                        mutex_shared.release()
            else:
                for value_name in func_dict:
                    self.global_shared[value_name].append(values_list[value_name])
                self._after_batch_sim(**kargs)
        finally:
            # Always release the device, otherwise a failed batch would mark
            # it as busy forever.
            if self.pi:
                if mutex_schedule.acquire():
                    try:
                        self.global_shared['device_used'][device] = 0
                    finally:
                        mutex_schedule.release()
示例#6
0
def train(net,
          mode,
          train_loader,
          optimizer,
          device,
          epoch,
          loss_f,
          custom_plasticity=False,
          sparse_reg=0):
    """Train ``net`` for one epoch and report accuracy and mean loss.

    :param net: model; for ``mode == "ann"`` it returns the output scores, for
        ``mode == "snn"`` it returns ``(scores, spiking_activity)``.
    :param mode: ``"ann"`` or ``"snn"``. In SNN mode the loss gains a
        ``sparse_reg * torch.norm(spiking_activity)`` term and the network
        state is reset after every batch.
    :param train_loader: iterable of ``(x, label)`` batches with a ``.dataset``
        attribute (used for the sample count).
    :param optimizer: optimizer over ``net``'s parameters.
    :param device: device the batches are moved to.
    :param epoch: epoch index; not used in the computation.
    :param loss_f: ``"mse"`` (against 10-class one-hot labels) or ``"ce"``.
    :param custom_plasticity: when true, after ``backward`` zero the gradient
        of every parameter entry that is ``>=`` the tensor's 95th percentile or
        whose negation is ``>=`` the 95th percentile of the negated tensor.
        Parameters without a gradient are skipped.
    :param sparse_reg: weight of the spiking-activity penalty (SNN only).
    :return: ``(train_loss, train_acc)``: mean per-batch loss and accuracy in
        percent over ``train_loader.dataset``.
    :raises ValueError: if ``mode`` or ``loss_f`` is not recognised.
    """
    if mode not in ("ann", "snn"):
        raise ValueError(f"unknown mode: {mode!r}")
    if loss_f not in ("mse", "ce"):
        raise ValueError(f"unknown loss_f: {loss_f!r}")

    net.train()
    correct_pred, train_loss = 0, 0

    for x, label in train_loader:
        x, label = x.to(device), label.to(device)
        optimizer.zero_grad()

        if mode == "ann":
            y = net(x)
        else:
            y, spiking_actv = net(x)

        if loss_f == "mse":
            loss = F.mse_loss(y, F.one_hot(label, 10).float())
        else:
            loss = F.cross_entropy(y, label)
        # argmax of a one-hot row is the class index, so both loss variants
        # share the same correctness check.
        correct_pred += (y.argmax(dim=1) == label).sum().item()

        if mode == "snn":
            loss = loss + sparse_reg * torch.norm(spiking_actv)

        loss.backward()

        if custom_plasticity:
            # Pure gradient surgery: no autograd graph is needed here.
            with torch.no_grad():
                for p in net.parameters():
                    if p.grad is None:
                        # Parameters unused in the forward pass have no
                        # gradient; masking None would raise a TypeError.
                        continue
                    q1 = p.quantile(0.95).item()
                    q2 = (-p).quantile(0.95).item()
                    p.grad = (p < q1) * (-p < q2) * p.grad

        optimizer.step()
        train_loss += loss.item()
        if mode == "snn":
            functional.reset_net(net)

    train_acc = 100. * correct_pred / len(train_loader.dataset)
    train_loss /= len(train_loader)

    print(
        "\n===> Train Epoch Accuracy : {:.2f}%, , Train Average loss: {:.8f}".
        format(train_acc, train_loss))
    return train_loss, train_acc
示例#7
0
 def select_action(state, steps_done):
     """Choose an action epsilon-greedily with exponentially decaying epsilon.

     Epsilon decays from ``EPS_START`` toward ``EPS_END`` with time constant
     ``EPS_DECAY`` steps. When exploiting, the greedy action of ``policy_net``
     is returned as a ``(1, 1)`` long tensor and the network state is reset.
     When exploring, one of two actions is drawn uniformly at random.
     """
     draw = random.random()
     decay = math.exp(-1. * steps_done / EPS_DECAY)
     epsilon = EPS_END + (EPS_START - EPS_END) * decay
     if draw > epsilon:
         with torch.no_grad():
             greedy = policy_net(state).max(1)[1].view(1, 1)
             functional.reset_net(policy_net)
         return greedy
     return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)
示例#8
0
def test_env(vis=False):
    """Play one episode in ``env`` with actions sampled from ``model``.

    Uses module-level ``env``, ``model``, ``device`` and ``functional``. When
    ``vis`` is true the environment is rendered after the reset and after
    every step.

    :return: the sum of rewards collected during the episode.
    """
    observation = env.reset()
    if vis:
        env.render()
    episode_return = 0
    finished = False
    while not finished:
        obs_tensor = torch.FloatTensor(observation).unsqueeze(0).to(device)
        dist, _ = model(obs_tensor)
        functional.reset_net(model)
        chosen = dist.sample().cpu().numpy()[0]
        observation, reward, finished, _ = env.step(chosen)
        if vis:
            env.render()
        episode_return += reward
    return episode_return
示例#9
0
# Actor-critic rollout collection: each outer iteration steps the vectorised
# `envs` `num_steps` times and records log-probs, values, rewards, masks and
# the summed entropy.
# NOTE(review): `envs`, `model`, `device`, `functional`, `step_idx`,
# `max_steps` and `num_steps` are defined outside this excerpt; the update
# step that consumes the collected lists presumably follows in the loop body
# beyond what is visible here — confirm.
state = envs.reset()

# NOTE(review): `logdir=` is the tensorboardX keyword; torch.utils.tensorboard's
# SummaryWriter takes `log_dir` — confirm which SummaryWriter is imported.
writer = SummaryWriter(logdir='./log')

while step_idx < max_steps:

    # Per-iteration buffers, one entry appended per environment step.
    log_probs = []
    values = []
    rewards = []
    masks = []
    entropy = 0

    for _ in range(num_steps):
        state = torch.FloatTensor(state).to(device)
        dist, value = model(state)
        # The model is stateful (SpikingJelly); reset after every forward pass.
        functional.reset_net(model)

        action = dist.sample()
        next_state, reward, done, _ = envs.step(action.cpu().numpy())

        log_prob = dist.log_prob(action)
        entropy += dist.entropy().mean()

        log_probs.append(log_prob)
        values.append(value)
        # unsqueeze(1) turns the per-env reward / mask vectors into columns.
        rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))
        # 1 - done: 0 for environments whose episode just ended, 1 otherwise.
        masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))

        state = next_state
        step_idx += 1
示例#10
0
def _run_poisson_snn(net, encoder, img, T):
    '''Feed ``T`` freshly Poisson-encoded copies of ``img`` through ``net`` and sum the outputs.

    The network state is not reset here; callers call ``functional.reset_net``
    when appropriate.

    :param T: number of simulation steps; must be at least 1.
    :return: element-wise sum of the ``T`` network outputs.
    :raises ValueError: if ``T < 1``.
    '''
    if T < 1:
        raise ValueError(f'T must be >= 1, got {T}')
    out_spikes_counter = net(encoder(img).float())
    for _ in range(T - 1):
        out_spikes_counter += net(encoder(img).float())
    return out_spikes_counter


def main():
    '''
    .. _lif_fc_mnist.main-cn:

    .. _lif_fc_mnist.main-en:

    :return: None

    Train a single-layer FC-LIF network (Flatten -> Linear(784, 10, no bias)
    -> LIFNode) on MNIST with Poisson-encoded inputs. Hyper-parameters are read
    interactively from stdin. After every epoch the test-set accuracy is
    evaluated and logged to TensorBoard. When training finishes, the output
    layer's recorded membrane potentials and spikes for the first test image
    are saved to ``v_t_array.npy`` / ``s_t_array.npy``. The per-batch train
    accuracies and per-epoch test accuracies are saved to ``train_accs.npy`` /
    ``test_accs.npy``. All four files go to the current working directory.
    '''
    device = input(
        '输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
    dataset_dir = input(
        '输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": '
    )
    batch_size = int(
        input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
    learning_rate = float(
        input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
    T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
    tau = float(
        input(
            '输入LIF神经元的时间常数tau,例如“100.0”\n input membrane time constant, tau, for LIF neurons, e.g., "100.0": '
        ))
    train_epoch = int(
        input(
            '输入训练轮数,即遍历训练集的次数,例如“100”\n input training epochs, e.g., "100": '))
    log_dir = input(
        '输入保存tensorboard日志文件的位置,例如“./”\n input root directory for saving tensorboard logs, e.g., "./": '
    )

    writer = SummaryWriter(log_dir)

    # Data loaders
    train_dataset = torchvision.datasets.MNIST(
        root=dataset_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    test_dataset = torchvision.datasets.MNIST(
        root=dataset_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True)

    train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=False,
                                                   drop_last=False)

    # Network: flatten -> bias-free linear layer -> LIF neurons
    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10, bias=False),
                        neuron.LIFNode(tau=tau))
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    # Poisson encoder: a new random spike sample is drawn at every call
    encoder = encoding.PoissonEncoder()
    train_times = 0
    max_test_accuracy = 0

    test_accs = []
    train_accs = []

    for epoch in range(train_epoch):
        net.train()
        for img, label in tqdm(train_data_loader):
            img = img.to(device)
            label = label.to(device)
            label_one_hot = F.one_hot(label, 10).float()

            optimizer.zero_grad()

            # Output spikes of the 10 output neurons summed over T steps
            out_spikes_counter = _run_poisson_snn(net, encoder, img, T)
            # Dividing by T gives each output neuron's firing rate
            out_spikes_counter_frequency = out_spikes_counter / T

            # MSE between firing rates and the one-hot label pushes the rate of
            # the true class's neuron toward 1 and all others toward 0
            loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
            loss.backward()
            optimizer.step()
            # Spiking neurons are stateful: reset after each optimisation step
            functional.reset_net(net)

            # Prediction = index of the output neuron with the highest rate
            accuracy = (out_spikes_counter_frequency.max(1)[1] ==
                        label).float().mean().item()

            writer.add_scalar('train_accuracy', accuracy, train_times)
            train_accs.append(accuracy)

            train_times += 1
        net.eval()
        with torch.no_grad():
            # Evaluate on the whole test set once per epoch
            test_sum = 0
            correct_sum = 0
            for img, label in test_data_loader:
                img = img.to(device)
                label = label.to(device)
                out_spikes_counter = _run_poisson_snn(net, encoder, img, T)

                correct_sum += (out_spikes_counter.max(1)[1] ==
                                label).float().sum().item()
                test_sum += label.numel()
                functional.reset_net(net)
            test_accuracy = correct_sum / test_sum
            writer.add_scalar('test_accuracy', test_accuracy, epoch)
            test_accs.append(test_accuracy)
            max_test_accuracy = max(max_test_accuracy, test_accuracy)
        print(
            f'Epoch {epoch}: device={device}, dataset_dir={dataset_dir}, batch_size={batch_size}, learning_rate={learning_rate}, T={T}, log_dir={log_dir}, max_test_accuracy={max_test_accuracy}, train_times={train_times}'
        )
    # Flush pending TensorBoard events and release the event file.
    writer.close()

    # Record data for plotting: run the first test image with monitoring on
    net.eval()
    functional.set_monitor(net, True)
    with torch.no_grad():
        img, label = test_dataset[0]
        img = img.to(device)
        out_spikes_counter = _run_poisson_snn(net, encoder, img, T)
        out_spikes_counter_frequency = (out_spikes_counter / T).cpu().numpy()
        print(f'Firing rate: {out_spikes_counter_frequency}')
        output_layer = net[-1]  # output (LIF) layer
        # v_t_array[i][j]: membrane potential of neuron i at step j
        v_t_array = np.asarray(output_layer.monitor['v']).squeeze().T
        np.save("v_t_array.npy", v_t_array)
        # s_t_array[i][j]: spike (0 or 1) emitted by neuron i at step j
        s_t_array = np.asarray(output_layer.monitor['s']).squeeze().T
        np.save("s_t_array.npy", s_t_array)

    train_accs = np.array(train_accs)
    np.save('train_accs.npy', train_accs)
    test_accs = np.array(test_accs)
    np.save('test_accs.npy', test_accs)
示例#11
0
            # Test-set pass: accumulate accuracy and, for every submodule that
            # has a `monitor`, the total number of recorded spikes per neuron.
            for img, label in test_data_loader:
                img = img.cuda(non_blocking=True)
                label = label.cuda(non_blocking=True)

                out_spikes_counter = net(img)

                correct_sum += (out_spikes_counter.argmax(dim=1) == label).float().sum().item()
                test_sum += label.numel()

                for name, module in net.named_modules():
                    if hasattr(module, 'monitor'):
                        # monitor['s'] is a list, each element is of shape [batch_size, ...]
                        # Concatenating along axis 0 and summing over dim 0 adds
                        # up spikes over all recorded entries and batch samples.
                        # NOTE(review): assumes reset_net clears module.monitor
                        # between batches; otherwise earlier batches' spikes are
                        # summed again — confirm.
                        spike_times[name] += torch.sum(torch.from_numpy(np.concatenate(module.monitor['s'], axis=0)).cuda(), dim=0)

                reset_net(net)

            test_accuracy = correct_sum / test_sum

############ 1. Firing Rate ###########
            # Per-neuron firing rate = total spikes / (T * number of test samples).
            # NOTE(review): presumes monitor['s'] records exactly T entries per
            # batch — confirm.
            print('Firing Rates:')
            for k, v in spike_times.items():
                rate = (v / (T * len(test_dataset))).flatten().cpu().numpy()

                # The file name encodes the pruning setting (no pruning, or the
                # penalty value in scientific notation).
                if no_prune:
                    filename = 'rate-' + k + '-no_prune.npy'  
                else:
                    filename = 'rate-' + k + '-' + np.format_float_scientific(penalty, exp_digits=1, trim='-') + '.npy'

                with open(os.path.join(log_dir, filename), 'wb') as f:
                    np.save(f, rate)
示例#12
0
                # Mixed-precision path: forward pass and loss under autocast,
                # backward and optimizer step through the GradScaler.
                with amp.autocast():
                    y = net(frame)
                    loss = F.mse_loss(y, label_onehot)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                # Full-precision path.
                y = net(frame)
                loss = F.mse_loss(y, label_onehot)
                loss.backward()
                optimizer.step()

            # Loss is weighted by batch size so the epoch value is a per-sample mean.
            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (y.argmax(1) == label).float().sum().item()
            # Clear spiking-neuron state before the next batch.
            functional.reset_net(net)

        # The scheduler type is compared against the string "None", not the None object.
        if lr_scheduler_type != "None":
            lr_scheduler.step()
        # NOTE(review): `end` presumably pairs with a start timestamp taken before this excerpt.
        end = time.time()
        train_loss /= train_samples
        train_acc /= train_samples
        logs["train_loss"].append(train_loss)
        logs["train_acc"].append(train_acc)

        net.eval()
        test_loss, test_acc, test_samples = 0, 0, 0
        with torch.no_grad():
            for frame, label in test_loader:
                frame, label = frame.float().to(device), label.to(device)
                # One-hot over 11 classes; presumably the dataset's class count — confirm.
                label_onehot = F.one_hot(label, 11).float()
def main():
    '''
    .. _conv_fashion_mnist_cuda_lbl.main-cn:

    .. _conv_fashion_mnist_cuda_lbl.main-en:

    :return: None

    The layer-by-layer version of :class:`spikingjelly.clock_driven.examples.conv_fashion_mnist`.
    Hyper-parameters are read interactively from stdin. Only GPU devices are
    supported: entering ``cpu`` prints a message and exits. After every epoch
    the test accuracy is logged to TensorBoard. Whenever the test accuracy
    improves, the whole network is saved to ``<log_dir>/net_max_acc.pt``.

    After 100 epochs, the accuracy on train batch and test dataset is as follows:

    .. image:: ./_static/tutorials/clock_driven/11_cext_neuron_with_lbl/train.*
        :width: 100%

    .. image:: ./_static/tutorials/clock_driven/11_cext_neuron_with_lbl/test.*
        :width: 100%
    '''
    device = input('输入运行的GPU,例如“cuda:0”\n input GPU index, e.g., "cuda:0": ')
    if device == 'cpu':
        print("conv_fashion_mnist_cuda_lbl only supports GPU.")
        # exit() is injected by the `site` module and is missing under
        # `python -S` or in frozen apps; SystemExit works everywhere.
        raise SystemExit
    dataset_dir = input(
        '输入保存Fashion MNIST数据集的位置,例如“./”\n input root directory for saving Fashion MNIST dataset, e.g., "./": '
    )
    batch_size = int(
        input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
    learning_rate = float(
        input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
    T = int(input('输入仿真时长,例如“8”\n input simulating steps, e.g., "8": '))
    tau = float(
        input(
            '输入LIF神经元的时间常数tau,例如“2.0”\n input membrane time constant, tau, for LIF neurons, e.g., "2.0": '
        ))
    train_epoch = int(
        input(
            '输入训练轮数,即遍历训练集的次数,例如“100”\n input training epochs, e.g., "100": '))
    log_dir = input(
        '输入保存tensorboard日志文件的位置,例如“./”\n input root directory for saving tensorboard logs, e.g., "./": '
    )

    writer = SummaryWriter(log_dir)

    # Data loaders
    train_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.FashionMNIST(
            root=dataset_dir,
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.FashionMNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=batch_size,
        shuffle=True,
        drop_last=False)

    net = Net(tau=tau, T=T).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    train_times = 0
    max_test_accuracy = 0
    for epoch in range(train_epoch):
        net.train()
        t_start = time.perf_counter()
        for img, label in train_data_loader:
            img = img.to(device)
            label = label.to(device)
            label_one_hot = F.one_hot(label, 10).float()

            optimizer.zero_grad()

            # Net runs all T steps internally and returns firing rates
            out_spikes_counter_frequency = net(img)

            # MSE between firing rates and the one-hot label pushes the rate of
            # the true class's neuron toward 1 and all others toward 0
            loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
            loss.backward()
            optimizer.step()
            # Spiking neurons are stateful: reset after each optimisation step
            functional.reset_net(net)

            # Prediction = index of the output neuron with the highest rate
            accuracy = (out_spikes_counter_frequency.max(1)[1] ==
                        label).float().mean().item()
            if train_times % 256 == 0:
                writer.add_scalar('train_accuracy', accuracy, train_times)
            train_times += 1
        t_train = time.perf_counter() - t_start
        net.eval()
        t_start = time.perf_counter()
        with torch.no_grad():
            # Evaluate on the whole test set once per epoch
            test_sum = 0
            correct_sum = 0
            for img, label in test_data_loader:
                img = img.to(device)
                label = label.to(device)
                out_spikes_counter_frequency = net(img)

                correct_sum += (out_spikes_counter_frequency.max(1)[1] ==
                                label).float().sum().item()
                test_sum += label.numel()
                functional.reset_net(net)
            test_accuracy = correct_sum / test_sum
            t_test = time.perf_counter() - t_start
            writer.add_scalar('test_accuracy', test_accuracy, epoch)
            if max_test_accuracy < test_accuracy:
                max_test_accuracy = test_accuracy
                print('saving net...')
                torch.save(net, log_dir + '/net_max_acc.pt')
                print('saved')

        print(
            'epoch={}, t_train={}, t_test={}, device={}, dataset_dir={}, batch_size={}, learning_rate={}, T={}, log_dir={}, max_test_accuracy={}, train_times={}'
            .format(epoch, t_train, t_test, device, dataset_dir, batch_size,
                    learning_rate, T, log_dir, max_test_accuracy, train_times))
    # Flush pending TensorBoard events and release the event file.
    writer.close()
示例#14
0
def simulate_snn(snn,
                 device,
                 data_loader,
                 T,
                 poisson=False,
                 online_draw=False,
                 fig_name='default',
                 ann_baseline=0,
                 save_acc_list=False,
                 log_dir=None):
    '''
    .. _simulate_snn-cn:

    .. _simulate_snn-en:

    Evaluate the classification accuracy of an SNN.

    Each batch is simulated for ``T`` steps and the outputs are summed; the
    predicted class is the index of the largest summed output. The number of
    correct predictions after the first ``t + 1`` steps is also accumulated
    for every ``t``, giving an accuracy-vs-T curve that can be drawn or saved.

    :param snn: SNN model; if it returns a tuple or list, its first element is used
    :param device: running device
    :param data_loader: testing data loader yielding ``(img, label)`` batches
    :param T: simulating steps; must be at least 1
    :param poisson: when ``True``, use a Poisson encoder; otherwise feed the constant input for T steps
    :param online_draw: when ``True``, redraw the accuracy-vs-T curve with matplotlib after every batch
    :param fig_name: prefix of the figure title and stem of the saved ``.pdf`` file
    :param ann_baseline: if non-zero, an ANN accuracy (in %) drawn as a reference line
    :param save_acc_list: when ``True``, save the accuracy-vs-T array (in %) to ``log_dir`` after every batch
    :param log_dir: output directory; the figure is saved only when it is a ``str``
    :return: SNN simulating accuracy as a fraction in [0, 1]
    :raises ValueError: if ``T < 1``, or ``save_acc_list`` is set without a ``str`` ``log_dir``
    '''
    if T < 1:
        raise ValueError(f'T must be >= 1, got {T}')
    if save_acc_list and not isinstance(log_dir, str):
        raise ValueError('save_acc_list=True requires log_dir to be a directory path')
    functional.reset_net(snn)
    if poisson:
        encoder = encoding.PoissonEncoder()
    correct_t = {}
    if online_draw:
        plt.ion()
    with torch.no_grad():
        snn.eval()
        correct = 0.0
        total = 0.0
        for img, label in data_loader:
            img = img.to(device)
            # Move labels once per batch instead of at every comparison.
            label = label.to(device)
            for t in tqdm.tqdm(range(T)):
                encoded = encoder(img).float() if poisson else img
                out = snn(encoded)
                if isinstance(out, (tuple, list)):
                    out = out[0]
                if t == 0:
                    out_spikes_counter = out
                else:
                    out_spikes_counter += out

                # Correct predictions using only the first t + 1 steps.
                step_correct = (out_spikes_counter.max(1)[1] ==
                                label).float().sum().item()
                correct_t[t] = correct_t.get(t, 0.0) + step_correct
            # After the loop step_correct is the count using all T steps.
            correct += step_correct
            total += label.numel()
            functional.reset_net(snn)
            if online_draw:
                plt.cla()
                x = np.array(list(correct_t.keys())).astype(np.float32) + 1
                y = np.array(list(correct_t.values())).astype(
                    np.float32) / total * 100
                plt.plot(x, y, label='SNN', c='b')
                if ann_baseline != 0:
                    plt.plot(x,
                             np.ones_like(x) * ann_baseline,
                             label='ANN',
                             c='g',
                             linestyle=':')
                    plt.text(0,
                             ann_baseline + 1,
                             "%.3f%%" % (ann_baseline),
                             fontdict={
                                 'size': '8',
                                 'color': 'g'
                             })
                plt.title("%s SNN Simulation \n[test samples:%.1f%%]" %
                          (fig_name, 100 * total / len(data_loader.dataset)))
                plt.xlabel("T")
                plt.ylabel("Accuracy(%)")
                plt.legend()
                argmax = np.argmax(y)
                # Shift the label left when the maximum is near the right edge.
                disp_bias = 0.3 * float(T) if x[argmax] / T > 0.7 else 0
                plt.text(x[argmax] - 0.8 - disp_bias,
                         y[argmax] + 0.8,
                         "MAX:%.3f%% T=%d" % (y[argmax], x[argmax]),
                         fontdict={
                             'size': '12',
                             'color': 'r'
                         })

                plt.scatter([x[argmax]], [y[argmax]], c='r')
                plt.pause(0.01)
                if isinstance(log_dir, str):
                    plt.savefig(log_dir + '/' + fig_name + ".pdf")
            print('[SNN Simulating... %.2f%%] Acc:%.3f' %
                  (100 * total / len(data_loader.dataset), correct / total))
            if save_acc_list:
                acc_list = np.array(list(correct_t.values())).astype(
                    np.float32) / total * 100
                np.save(
                    log_dir + '/snn_acc-list' +
                    ('-poisson' if poisson else '-constant'), acc_list)
        acc = correct / total
        print('SNN Simulating Accuracy:%.3f' % (acc))

    if online_draw:
        # Same guard as inside the loop: the default log_dir=None used to
        # crash here with a TypeError.
        if isinstance(log_dir, str):
            plt.savefig(log_dir + '/' + fig_name + ".pdf")
        plt.ioff()
        plt.close()
    return acc
def main():
    '''
    .. _conv_fashion_mnist.main-cn:

    .. _conv_fashion_mnist.main-en:

    :return: None

    The network with Conv-FC structure for classifying Fashion MNIST. This function reads hyper-parameters
    interactively from stdin, initialises the network, trains it and reports the accuracy on the test dataset after
    every epoch. Whenever the test accuracy improves, the whole network is saved to ``<log_dir>/net_max_acc.pt``,
    where ``log_dir`` is the TensorBoard log directory entered by the user.

    After 100 epochs, the accuracy on train batch and test dataset is as follows:

    .. image:: ./_static/tutorials/clock_driven/4_conv_fashion_mnist/train.*
        :width: 100%

    .. image:: ./_static/tutorials/clock_driven/4_conv_fashion_mnist/test.*
        :width: 100%
    '''
    device = input(
        '输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
    dataset_dir = input(
        '输入保存Fashion MNIST数据集的位置,例如“./”\n input root directory for saving Fashion MNIST dataset, e.g., "./": '
    )
    batch_size = int(
        input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
    learning_rate = float(
        input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
    T = int(input('输入仿真时长,例如“8”\n input simulating steps, e.g., "8": '))
    tau = float(
        input(
            '输入LIF神经元的时间常数tau,例如“2.0”\n input membrane time constant, tau, for LIF neurons, e.g., "2.0": '
        ))
    train_epoch = int(
        input(
            '输入训练轮数,即遍历训练集的次数,例如“100”\n input training epochs, e.g., "100": '))
    log_dir = input(
        '输入保存tensorboard日志文件的位置,例如“./”\n input root directory for saving tensorboard logs, e.g., "./": '
    )

    writer = SummaryWriter(log_dir)

    # Data loaders
    train_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.FashionMNIST(
            root=dataset_dir,
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.FashionMNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=batch_size,
        shuffle=True,
        drop_last=False)

    net = Net(tau=tau, T=T).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    train_times = 0
    max_test_accuracy = 0
    for epoch in range(train_epoch):
        net.train()
        for img, label in train_data_loader:
            img = img.to(device)
            label = label.to(device)
            label_one_hot = F.one_hot(label, 10).float()

            optimizer.zero_grad()

            # Net runs all T steps internally and returns firing rates
            out_spikes_counter_frequency = net(img)

            # MSE between firing rates and the one-hot label pushes the rate of
            # the true class's neuron toward 1 and all others toward 0
            loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
            loss.backward()
            optimizer.step()
            # Spiking neurons are stateful: reset after each optimisation step
            functional.reset_net(net)

            # Prediction = index of the output neuron with the highest rate
            accuracy = (out_spikes_counter_frequency.max(1)[1] ==
                        label).float().mean().item()
            if train_times % 256 == 0:
                writer.add_scalar('train_accuracy', accuracy, train_times)
            train_times += 1
        net.eval()
        with torch.no_grad():
            # Evaluate on the whole test set once per epoch
            test_sum = 0
            correct_sum = 0
            for img, label in test_data_loader:
                img = img.to(device)
                label = label.to(device)
                out_spikes_counter_frequency = net(img)

                correct_sum += (out_spikes_counter_frequency.max(1)[1] ==
                                label).float().sum().item()
                test_sum += label.numel()
                functional.reset_net(net)
            test_accuracy = correct_sum / test_sum
            writer.add_scalar('test_accuracy', test_accuracy, epoch)
            if max_test_accuracy < test_accuracy:
                max_test_accuracy = test_accuracy
                print('saving net...')
                torch.save(net, log_dir + '/net_max_acc.pt')
                print('saved')

        print(
            'device={}, dataset_dir={}, batch_size={}, learning_rate={}, T={}, log_dir={}, max_test_accuracy={}, train_times={}'
            .format(device, dataset_dir, batch_size, learning_rate, T, log_dir,
                    max_test_accuracy, train_times))
    # Flush pending TensorBoard events and release the event file.
    writer.close()
示例#16
0
def main():
    # python classify_dvsg.py -data_dir /userhome/datasets/DVS128Gesture -out_dir ./logs -amp -opt Adam -device cuda:0 -lr_scheduler CosALR -T_max 64 -cext -epochs 1024
    '''
    .. _classify_dvsg.__init__-cn:

    .. _classify_dvsg.__init__-en:

    Code example for classifying the DVS128 Gesture dataset. The network structure is from `Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_.

    Every epoch does the following:

    - Trains on the training split with an MSE loss between the network
      output and the one-hot label (11 classes).
    - Evaluates on the test split.
    - Logs loss/accuracy to TensorBoard.
    - Saves ``checkpoint_latest.pth``, plus ``checkpoint_max.pth`` whenever
      the test accuracy reaches a new maximum, in a run directory under
      ``-out_dir`` whose name encodes the main hyper-parameters.

    ``-resume`` restores the network, optimizer and LR scheduler. It also
    restores the AMP grad scaler when the checkpoint contains one. Training
    then continues from the epoch after the saved one.

    .. code:: bash

        usage: classify_dvsg.py [-h] [-T T] [-device DEVICE] [-b B] [-epochs N] [-j N] [-channels CHANNELS] -data_dir DATA_DIR -out_dir OUT_DIR [-resume RESUME] [-amp] [-cext] -opt {SGD,Adam} [-lr LR] [-momentum MOMENTUM] [-lr_scheduler {StepLR,CosALR}] [-step_size STEP_SIZE] [-gamma GAMMA] [-T_max T_MAX]

        Classify DVS128 Gesture

        arguments:
          -h, --help            show this help message and exit
          -T T                  simulating time-steps (default: 16)
          -device DEVICE        device (default: cuda:0)
          -b B                  batch size (default: 16)
          -epochs N             number of total epochs to run (default: 64)
          -j N                  number of data loading workers (default: 4)
          -channels CHANNELS    channels of Conv2d in SNN (default: 128)
          -data_dir DATA_DIR    root dir of DVS128 Gesture dataset (required)
          -out_dir OUT_DIR      root dir for saving logs and checkpoint (required;
                                missing directories are created)
          -resume RESUME        resume from the checkpoint path
          -amp                  automatic mixed precision training
          -cext                 use CUDA neuron and multi-step forward mode
          -opt {SGD,Adam}       use which optimizer (required)
          -lr LR                learning rate (default: 0.001)
          -momentum MOMENTUM    momentum for SGD (default: 0.9)
          -lr_scheduler {StepLR,CosALR}
                                use which schedule (default: CosALR)
          -step_size STEP_SIZE  step_size for StepLR, in epochs (default: 32)
          -gamma GAMMA          gamma for StepLR (default: 0.1)
          -T_max T_MAX          T_max for CosineAnnealingLR (default: 32)

    Running example:

    .. code:: bash

        python -m spikingjelly.clock_driven.examples.classify_dvsg -data_dir /userhome/datasets/DVS128Gesture -out_dir ./logs -amp -opt Adam -device cuda:0 -lr_scheduler CosALR -T_max 64 -cext -epochs 1024

    See the tutorial :doc:`./clock_driven_en/14_classify_dvsg` (Chinese version:
    :doc:`./clock_driven/14_classify_dvsg`) for more details.
    '''
    parser = argparse.ArgumentParser(description='Classify DVS128 Gesture')
    parser.add_argument('-T',
                        default=16,
                        type=int,
                        help='simulating time-steps')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=16, type=int, help='batch size')
    parser.add_argument('-epochs',
                        default=64,
                        type=int,
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-channels',
                        default=128,
                        type=int,
                        help='channels of Conv2d in SNN')
    # Both directories are dereferenced unconditionally below. Requiring them
    # gives a clear argparse error instead of a crash deep inside the script.
    parser.add_argument('-data_dir',
                        type=str,
                        required=True,
                        help='root dir of DVS128 Gesture dataset')
    parser.add_argument('-out_dir',
                        type=str,
                        required=True,
                        help='root dir for saving logs and checkpoint')

    parser.add_argument('-resume',
                        type=str,
                        help='resume from the checkpoint path')
    parser.add_argument('-amp',
                        action='store_true',
                        help='automatic mixed precision training')
    parser.add_argument('-cext',
                        action='store_true',
                        help='use CUDA neuron and multi-step forward mode')

    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        choices=('SGD', 'Adam'),
                        help='use which optimizer. SGD or Adam')
    parser.add_argument('-lr', default=0.001, type=float, help='learning rate')
    parser.add_argument('-momentum',
                        default=0.9,
                        type=float,
                        help='momentum for SGD')
    parser.add_argument('-lr_scheduler',
                        default='CosALR',
                        type=str,
                        choices=('StepLR', 'CosALR'),
                        help='use which schedule. StepLR or CosALR')
    # StepLR counts scheduler steps, and the scheduler is stepped once per
    # epoch, so the step size is an integer number of epochs.
    parser.add_argument('-step_size',
                        default=32,
                        type=int,
                        help='step_size for StepLR')
    parser.add_argument('-gamma',
                        default=0.1,
                        type=float,
                        help='gamma for StepLR')
    parser.add_argument('-T_max',
                        default=32,
                        type=int,
                        help='T_max for CosineAnnealingLR')

    args = parser.parse_args()
    print(args)

    if args.cext:
        net = CextNet(channels=args.channels)
    else:
        net = PythonNet(channels=args.channels)
    print(net)
    net.to(args.device)

    optimizer = None
    if args.opt == 'SGD':
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    elif args.opt == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(args.opt)

    lr_scheduler = None
    if args.lr_scheduler == 'StepLR':
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.step_size, gamma=args.gamma)
    elif args.lr_scheduler == 'CosALR':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.T_max)
    else:
        raise NotImplementedError(args.lr_scheduler)

    # Each sample is integrated into T frames (split by event count).
    train_set = DVS128Gesture(args.data_dir,
                              train=True,
                              data_type='frame',
                              split_by='number',
                              frames_number=args.T)
    test_set = DVS128Gesture(args.data_dir,
                             train=False,
                             data_type='frame',
                             split_by='number',
                             frames_number=args.T)

    train_data_loader = DataLoader(dataset=train_set,
                                   batch_size=args.b,
                                   shuffle=True,
                                   num_workers=args.j,
                                   drop_last=True,
                                   pin_memory=True)

    test_data_loader = DataLoader(dataset=test_set,
                                  batch_size=args.b,
                                  shuffle=False,
                                  num_workers=args.j,
                                  drop_last=False,
                                  pin_memory=True)

    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    start_epoch = 0
    max_test_acc = 0

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        # Older checkpoints have no 'scaler' entry; keep the fresh scaler then.
        if scaler is not None and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    out_dir = os.path.join(
        args.out_dir,
        f'T_{args.T}_b_{args.b}_c_{args.channels}_{args.opt}_lr_{args.lr}_')
    if args.lr_scheduler == 'CosALR':
        out_dir += f'CosALR_{args.T_max}'
    elif args.lr_scheduler == 'StepLR':
        out_dir += f'StepLR_{args.step_size}_{args.gamma}'
    else:
        raise NotImplementedError(args.lr_scheduler)

    if args.amp:
        out_dir += '_amp'
    if args.cext:
        out_dir += '_cext'

    if not os.path.exists(out_dir):
        # makedirs: -out_dir itself may not exist yet.
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    with open(os.path.join(out_dir, 'args.txt'), 'w',
              encoding='utf-8') as args_txt:
        args_txt.write(str(args))

    writer = SummaryWriter(os.path.join(out_dir, 'dvsg_logs'),
                           purge_step=start_epoch)

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for frame, label in train_data_loader:
            optimizer.zero_grad()
            frame = frame.float().to(args.device)
            label = label.to(args.device)
            label_onehot = F.one_hot(label, 11).float()
            if args.amp:
                with amp.autocast():
                    out_fr = net(frame)
                    loss = F.mse_loss(out_fr, label_onehot)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                out_fr = net(frame)
                loss = F.mse_loss(out_fr, label_onehot)
                loss.backward()
                optimizer.step()

            # Accumulate sums weighted by batch size so the epoch mean is exact.
            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            # Spiking neurons keep state between calls; clear it per batch.
            functional.reset_net(net)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)
        lr_scheduler.step()

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for frame, label in test_data_loader:
                frame = frame.float().to(args.device)
                label = label.to(args.device)
                label_onehot = F.one_hot(label, 11).float()
                out_fr = net(frame)
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)

        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }
        if scaler is not None:
            # Persist the loss-scale state so an -amp run resumes seamlessly.
            checkpoint['scaler'] = scaler.state_dict()

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(
            f'epoch={epoch}, train_loss={train_loss}, train_acc={train_acc}, test_loss={test_loss}, test_acc={test_acc}, max_test_acc={max_test_acc}, total_time={time.time() - start_time}'
        )

    # Flush pending TensorBoard events and release the event file.
    writer.close()
示例#17
0
def play(use_cuda,
         pt_path,
         env_name,
         hidden_size,
         played_frames=60,
         save_fig_num=0,
         fig_dir=None,
         figsize=(12, 6),
         firing_rates_plot_type='bar',
         heatmap_shape=None):
    '''
    Run a trained DQSN policy in a gym environment and visualize it live.

    Every frame shows three panels:

    - the rendered game screen;
    - the output-layer voltages (the chosen action is the index of the
      larger one);
    - the firing rates of the neurons in ``policy_net.fc[1]``, drawn as a
      bar chart or a heatmap.

    The subtitle labels the observation as a CartPole state (position,
    velocity, pole angle, pole tip velocity). Other environments will get
    mislabeled subtitles.

    :param use_cuda: run the network on ``cuda`` if true, otherwise on ``cpu``
    :param pt_path: path of the saved ``state_dict`` of the policy network
    :param env_name: gym environment id, e.g. ``'CartPole-v0'``
    :param hidden_size: hidden size the policy network was built with
    :param played_frames: the function returns at the first frame whose
        index is ``>= played_frames`` and at which the environment reports
        ``done``
    :param save_fig_num: the first ``save_fig_num`` frames are saved as
        ``<fig_dir>/<i>.png``
    :param fig_dir: output directory for saved frames (needed when
        ``save_fig_num > 0``)
    :param figsize: matplotlib figure size
    :param firing_rates_plot_type: ``'bar'`` or ``'heatmap'``
    :param heatmap_shape: shape the firing-rate vector is reshaped to when
        ``firing_rates_plot_type == 'heatmap'``

    Example::

        play(use_cuda=False, pt_path='./model/CartPole-v0/policy_net_256_max.pt',
             env_name='CartPole-v0', hidden_size=256, played_frames=300)
    '''

    T = 16

    plt.rcParams['figure.figsize'] = figsize
    plt.ion()

    env = gym.make(env_name).unwrapped

    device = torch.device("cuda" if use_cuda else "cpu")

    n_states = env.observation_space.shape[0]
    n_actions = env.action_space.n

    policy_net = DQSN(n_states, hidden_size, n_actions, T).to(device)
    policy_net.load_state_dict(torch.load(pt_path, map_location=device))

    # Start from the environment's real initial observation, shaped [1, n_states].
    state = torch.from_numpy(np.asarray(
        env.reset())).float().to(device).unsqueeze(0)

    with torch.no_grad():
        functional.set_monitor(policy_net, True)
        delta_lim = 0
        # Score at the first game over; later frames keep showing it.
        over_score = 1e9

        for i in count():
            # Redraw all panels from scratch. Recent matplotlib no longer
            # removes overlapping axes in subplot2grid, so axes (and heatmap
            # colorbars) would otherwise pile up frame after frame.
            plt.clf()

            LIF_v = policy_net(state)  # presumably shape=[1, n_actions]
            action = LIF_v.max(1)[1].view(1, 1).item()
            # matplotlib cannot consume CUDA tensors; plot from a host copy.
            v = LIF_v.squeeze(0).cpu().numpy()
            v_min, v_max = float(v.min()), float(v.max())

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (1, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (1, 0))

            plt.xticks(np.arange(2), ('Left', 'Right'))
            plt.ylabel('Voltage')
            plt.title('Voltage of LIF neurons at last time step')
            delta_lim = (v_max - v_min) * 0.5
            plt.ylim(v_min - delta_lim, v_max + delta_lim)
            plt.yticks([])
            plt.text(0, v[0], str(round(float(v[0]), 2)), ha='center')
            plt.text(1, v[1], str(round(float(v[1]), 2)), ha='center')

            # Highlight the bar of the selected action in red.
            plt.bar(np.arange(2),
                    v,
                    color=['r', 'gray'] if action == 0 else ['gray', 'r'],
                    width=0.5)

            if v_min - delta_lim < 0:
                plt.axhline(0, color='black', linewidth=0.1)

            # Spike records of the monitored layer, presumably
            # shape=[T, 1, hidden_size]; averaging over time gives firing rates.
            IF_spikes = np.asarray(policy_net.fc[1].monitor['s'])
            firing_rates = IF_spikes.mean(axis=0).squeeze()

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 4), rowspan=2, colspan=5)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 1), rowspan=2, colspan=2)

            plt.title('Firing rates of IF neurons')

            if firing_rates_plot_type == 'bar':
                # Draw a bar chart.
                plt.xlabel('Neuron index')
                plt.ylabel('Firing rate')
                plt.xlim(0, firing_rates.size)
                plt.ylim(0, 1.01)
                plt.bar(np.arange(firing_rates.size), firing_rates, width=0.5)

            elif firing_rates_plot_type == 'heatmap':
                # Draw a heatmap.
                heatmap = plt.imshow(firing_rates.reshape(heatmap_shape),
                                     vmin=0,
                                     vmax=1,
                                     cmap='ocean')
                plt.gca().invert_yaxis()
                cbar = heatmap.figure.colorbar(heatmap)
                cbar.ax.set_ylabel('Magnitude', rotation=90, va='top')

            # Clear neuron state (and monitor records) before the next frame.
            functional.reset_net(policy_net)
            subtitle = f'Position={state[0][0].item(): .2f}, Velocity={state[0][1].item(): .2f}, Pole Angle={state[0][2].item(): .2f}, Pole Velocity At Tip={state[0][3].item(): .2f}, Score={i}'

            state, reward, done, _ = env.step(action)

            if done:
                over_score = min(over_score, i)
                subtitle = f'Game over, Score={over_score}'
            plt.suptitle(subtitle)

            state = torch.from_numpy(state).float().to(device).unsqueeze(0)
            screen = env.render(mode='rgb_array').copy()
            screen[300, :, :] = 0  # Draw a black horizontal line at row 300.

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 0))

            plt.xticks([])
            plt.yticks([])
            plt.title('Game screen')
            plt.imshow(screen, interpolation='bicubic')
            plt.pause(0.001)

            if i < save_fig_num:
                plt.savefig(os.path.join(fig_dir, f'{i}.png'))

            if done and i >= played_frames:
                env.close()
                plt.close()
                break
示例#18
0
def _evaluate_fewshot(net, loader):
    '''
    Return ``(mean_mse_loss, accuracy)`` of ``net`` over all samples in ``loader``.

    Batches are moved to the module-level ``device``. Labels are one-hot
    encoded over 11 classes, matching the training loss. The network state
    is reset after every batch. The caller is responsible for ``net.eval()``.
    Raises ``ZeroDivisionError`` if ``loader`` yields no samples.
    '''
    total_loss, correct, samples = 0.0, 0.0, 0
    with torch.no_grad():
        for frame, label in loader:
            frame, label = frame.float().to(device), label.to(device)
            y = net(frame)
            loss = F.mse_loss(y, F.one_hot(label, 11).float())

            # Weight by batch size so the result is a per-sample mean.
            samples += label.numel()
            total_loss += loss.item() * label.numel()
            correct += (y.argmax(1) == label).float().sum().item()
            functional.reset_net(net)
    return total_loss / samples, correct / samples


def main_func(lr, batch_size, freeze_conv, reinit_fc):
    '''
    Fine-tune the pretrained network on few-shot splits and pickle the metrics.

    Runs once for every ``k`` in the module-level ``k_shots`` and every seed
    in ``seeds``:

    - Loads a fresh network from
      ``Exps/<exp_name>/model_weights/pretrain_net.pth``.
    - Trains it for one pass over the ``k``-shot training set.
    - Evaluates it on the few-shot train set, the few-shot test set and
      ``pre_loader``. ``pre_loader`` presumably holds the pretraining classes
      (TODO confirm).

    Per-seed metrics for each ``k`` are pickled to
    ``Exps/<exp_name>/logs/<run name>/logs_<k>-shot.pickle``.

    :param lr: Adam learning rate
    :param batch_size: batch size of the few-shot data loaders
    :param freeze_conv: if true, parameters of ``net.conv`` are not trained
    :param reinit_fc: if true, layers wrapped in ``SeqToANNContainer`` inside
        ``net.fc`` are re-initialized before fine-tuning
    '''
    import contextlib

    fs_exp_name = f"lr{lr}_bs{batch_size}_frzconv{freeze_conv}_reinitfc{reinit_fc}"
    fs_exp_path = os.path.join("Exps", exp_name, "logs", fs_exp_name)

    print(
        f"\n+======================= {fs_exp_name}   ==========================+\n"
    )

    # makedirs: the intermediate "Exps/<exp_name>/logs" directory may be missing.
    os.makedirs(fs_exp_path, exist_ok=True)

    for k in k_shots:
        print(f"==> {k}-shot ...")
        logs = {
            "train_acc": [],
            "train_loss": [],
            "test_acc": [],
            "test_loss": [],
            "pre_acc": []
        }
        for seed in seeds:
            utils.set_seed(seed)
            # Silence the dataset builders. redirect_stdout restores sys.stdout
            # even if they raise, and the devnull handle is closed afterwards.
            with open(os.devnull, "w") as devnull, \
                    contextlib.redirect_stdout(devnull):
                dataset_train = dataset.dataset_prepare_fewshot(
                    few_shot_classes, k, train=True)
                dataset_test = dataset.dataset_prepare_fewshot(
                    few_shot_classes, k, train=False)

            train_loader = torch.utils.data.DataLoader(dataset_train,
                                                       batch_size,
                                                       shuffle=True,
                                                       pin_memory=True)
            test_loader = torch.utils.data.DataLoader(dataset_test,
                                                      batch_size,
                                                      pin_memory=True)

            if use_cext:
                net = models.CextNet().to(device)
            else:
                net = models.SJSNN().to(device)

            # map_location lets GPU-saved weights load on any target device.
            net.load_state_dict(
                torch.load(os.path.join("Exps", exp_name, "model_weights",
                                        "pretrain_net.pth"),
                           map_location=device))

            if freeze_conv:
                for param in net.conv.parameters():
                    param.requires_grad = False
            if reinit_fc:
                for layer_ in net.fc.children():
                    # The trainable layer lives in the container's ``.module``.
                    # The container presumably has no reset_parameters of its
                    # own (verify against spikingjelly.layer), so test the
                    # wrapped module rather than the wrapper.
                    if isinstance(layer_, layer.SeqToANNContainer) and hasattr(
                            layer_.module, 'reset_parameters'):
                        layer_.module.reset_parameters()

            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                net.parameters()),
                                         lr=lr)
            lr_scheduler = None
            if lr_scheduler_type == "CosAlr":
                lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, T_max=64)
            elif lr_scheduler_type != "None":
                # Fail loudly here instead of with a NameError at step() time.
                raise NotImplementedError(lr_scheduler_type)
            scaler = amp.GradScaler() if use_amp else None

            # Fine-tune: a single pass over the few-shot training set.
            net.train()
            for frame, label in train_loader:
                optimizer.zero_grad()
                frame, label = frame.float().to(device), label.to(device)
                label_onehot = F.one_hot(label, 11).float()
                if use_amp:
                    with amp.autocast():
                        y = net(frame)
                        loss = F.mse_loss(y, label_onehot)
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    y = net(frame)
                    loss = F.mse_loss(y, label_onehot)
                    loss.backward()
                    optimizer.step()

                functional.reset_net(net)

            if lr_scheduler is not None:
                lr_scheduler.step()

            net.eval()
            train_loss, train_acc = _evaluate_fewshot(net, train_loader)
            logs["train_loss"].append(train_loss)
            logs["train_acc"].append(train_acc)

            test_loss, test_acc = _evaluate_fewshot(net, test_loader)
            logs["test_loss"].append(test_loss)
            logs["test_acc"].append(test_acc)

            _, pre_acc = _evaluate_fewshot(net, pre_loader)
            logs["pre_acc"].append(pre_acc)

            #print(f"===> {k}-shot with seed : {seed} : Train Loss = {1000*train_loss:.8f}, Train Accuracy = {100*train_acc:.2f}%, Test Accuracy = {100*test_acc:.2f}%, Pre Accuracy = {100*pre_acc:.2f}%")

        with open(os.path.join(fs_exp_path, f"logs_{k}-shot.pickle"),
                  "wb") as file:
            pickle.dump(logs, file)