Example No. 1
    def ppo_update(ppo_epochs,
                   mini_batch_size,
                   states,
                   actions,
                   log_probs,
                   returns,
                   advantages,
                   clip_param=0.2):
        for _ in range(ppo_epochs):
            for state, action, old_log_probs, return_, advantage in ppo_iter(
                    mini_batch_size, states, actions, log_probs, returns,
                    advantages):
                dist, value = model(state)
                functional.reset_net(model)  # SNN neurons are stateful; reset membrane potentials after each forward pass
                entropy = dist.entropy().mean()
                new_log_probs = dist.log_prob(action)

                ratio = (new_log_probs - old_log_probs).exp()
                surr1 = ratio * advantage
                surr2 = torch.clamp(ratio, 1.0 - clip_param,
                                    1.0 + clip_param) * advantage

                actor_loss = -torch.min(surr1, surr2).mean()
                critic_loss = (return_ - value).pow(2).mean()

                loss = 0.5 * critic_loss + actor_loss - 0.001 * entropy

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
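Note: `ppo_update` iterates over mini-batches produced by a `ppo_iter` helper that this snippet does not define. A minimal sketch of such a sampler (an assumption about the surrounding code, not shown in the original) could be:

    import numpy as np

    def ppo_iter(mini_batch_size, states, actions, log_probs, returns, advantages):
        # yield random mini-batches drawn from the collected rollout tensors
        batch_size = states.size(0)
        for _ in range(batch_size // mini_batch_size):
            ids = np.random.randint(0, batch_size, mini_batch_size)
            yield (states[ids], actions[ids], log_probs[ids],
                   returns[ids], advantages[ids])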
Example No. 2
    def optimize_model():
        if len(memory) < BATCH_SIZE:
            return
        transitions = memory.sample(BATCH_SIZE)

        batch = Transition(*zip(*transitions))

        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                                batch.next_state)), device=device, dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state
                                           if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        state_action_values = policy_net(state_batch).gather(1, action_batch)

        next_state_values = torch.zeros(BATCH_SIZE, device=device)
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
        functional.reset_net(target_net)  # the target net is also stateful; clear it after its forward pass
        expected_state_action_values = (next_state_values * GAMMA) + reward_batch

        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

        optimizer.zero_grad()
        loss.backward()
        for param in policy_net.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        optimizer.step()
        functional.reset_net(policy_net)
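`optimize_model` assumes a `Transition` record and a replay `memory` in the style of the standard PyTorch DQN tutorial; a minimal sketch of those pieces (assumed, not shown in the original) is:

    import random
    from collections import namedtuple, deque

    Transition = namedtuple('Transition',
                            ('state', 'action', 'next_state', 'reward'))

    class ReplayMemory:
        def __init__(self, capacity):
            self.memory = deque([], maxlen=capacity)

        def push(self, *args):
            # save a transition
            self.memory.append(Transition(*args))

        def sample(self, batch_size):
            return random.sample(self.memory, batch_size)

        def __len__(self):
            return len(self.memory)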
Example No. 3
 def select_action(state, steps_done):
     sample = random.random()
     eps_threshold = EPS_END + (EPS_START - EPS_END) * \
                     math.exp(-1. * steps_done / EPS_DECAY)
     if sample > eps_threshold:
         with torch.no_grad():
             ac = policy_net(state).max(1)[1].view(1, 1)
             functional.reset_net(policy_net)
             return ac
     else:
         return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)
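For intuition, the exploration threshold decays exponentially from EPS_START toward EPS_END. With illustrative constants (EPS_START=0.9, EPS_END=0.05 and EPS_DECAY=200 are assumptions, not values from the snippet), the decay behaves like this:

    import math

    EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 200  # assumed values for illustration
    for steps_done in (0, 200, 1000):
        eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
        print(steps_done, round(eps, 3))  # -> 0.9, 0.363, 0.056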
Example No. 4
 def test_env(vis=False):
     state = env.reset()
     if vis: env.render()
     done = False
     total_reward = 0
     while not done:
         state = torch.FloatTensor(state).unsqueeze(0).to(device)
         dist, _ = model(state)
         functional.reset_net(model)
         next_state, reward, done, _ = env.step(
             dist.sample().cpu().numpy()[0])
         state = next_state
         if vis: env.render()
         total_reward += reward
     return total_reward
Example No. 5
def main():
    '''
    (sj-dev) wfang@Precision-5820-Tower-X-Series:~/spikingjelly_dev$ python -m spikingjelly.activation_based.examples.conv_fashion_mnist -h

    usage: conv_fashion_mnist.py [-h] [-T T] [-device DEVICE] [-b B] [-epochs N] [-j N] [-data-dir DATA_DIR] [-out-dir OUT_DIR]
                                 [-resume RESUME] [-amp] [-cupy] [-opt OPT] [-momentum MOMENTUM] [-lr LR]

    Classify Fashion-MNIST

    optional arguments:
      -h, --help          show this help message and exit
      -T T                simulating time-steps
      -device DEVICE      device
      -b B                batch size
      -epochs N           number of total epochs to run
      -j N                number of data loading workers (default: 4)
      -data-dir DATA_DIR  root dir of Fashion-MNIST dataset
      -out-dir OUT_DIR    root dir for saving logs and checkpoint
      -resume RESUME      resume from the checkpoint path
      -amp                automatic mixed precision training
      -cupy               use cupy neuron and multi-step forward mode
      -opt OPT            use which optimizer. SGD or Adam
      -momentum MOMENTUM  momentum for SGD
      -save-es            dir for saving a batch of spikes encoded by the first {Conv2d-BatchNorm2d-IFNode}
    '''
    # python -m spikingjelly.activation_based.examples.conv_fashion_mnist -T 4 -device cuda:0 -b 128 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8

    # python -m spikingjelly.activation_based.examples.conv_fashion_mnist -T 4 -device cuda:0 -b 4 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8 -resume ./logs/T4_b256_sgd_lr0.1_c128_amp_cupy/checkpoint_latest.pth -save-es ./logs
    parser = argparse.ArgumentParser(description='Classify Fashion-MNIST')
    parser.add_argument('-T',
                        default=4,
                        type=int,
                        help='simulating time-steps')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=128, type=int, help='batch size')
    parser.add_argument('-epochs',
                        default=64,
                        type=int,
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-data-dir',
                        type=str,
                        help='root dir of Fashion-MNIST dataset')
    parser.add_argument('-out-dir',
                        type=str,
                        default='./logs',
                        help='root dir for saving logs and checkpoint')
    parser.add_argument('-resume',
                        type=str,
                        help='resume from the checkpoint path')
    parser.add_argument('-amp',
                        action='store_true',
                        help='automatic mixed precision training')
    parser.add_argument('-cupy', action='store_true', help='use cupy backend')
    parser.add_argument('-opt',
                        type=str,
                        help='use which optimizer. SGD or Adam')
    parser.add_argument('-momentum',
                        default=0.9,
                        type=float,
                        help='momentum for SGD')
    parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('-channels',
                        default=128,
                        type=int,
                        help='channels of CSNN')
    parser.add_argument(
        '-save-es',
        default=None,
        help='dir for saving a batch of spikes encoded by the first {Conv2d-BatchNorm2d-IFNode}'
    )

    args = parser.parse_args()
    print(args)

    net = CSNN(T=args.T, channels=args.channels, use_cupy=args.cupy)

    print(net)

    net.to(args.device)

    train_set = torchvision.datasets.FashionMNIST(
        root=args.data_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)

    test_set = torchvision.datasets.FashionMNIST(
        root=args.data_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True)

    train_data_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                    batch_size=args.b,
                                                    shuffle=True,
                                                    drop_last=True,
                                                    num_workers=args.j,
                                                    pin_memory=True)

    test_data_loader = torch.utils.data.DataLoader(dataset=test_set,
                                                   batch_size=args.b,
                                                   shuffle=True,
                                                   drop_last=False,
                                                   num_workers=args.j,
                                                   pin_memory=True)

    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    start_epoch = 0
    max_test_acc = -1

    optimizer = None
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(args.opt)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epochs)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']
        if args.save_es is not None and args.save_es != '':
            encoder = net.spiking_encoder()
            with torch.no_grad():
                for img, label in test_data_loader:
                    img = img.to(args.device)
                    label = label.to(args.device)
                    # img.shape = [N, C, H, W]
                    img_seq = img.unsqueeze(0).repeat(
                        net.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]
                    spike_seq = encoder(img_seq)
                    functional.reset_net(encoder)
                    to_pil_img = torchvision.transforms.ToPILImage()
                    vs_dir = os.path.join(args.save_es, 'visualization')
                    os.mkdir(vs_dir)

                    img = img.cpu()
                    spike_seq = spike_seq.cpu()

                    img = F.interpolate(img, scale_factor=4, mode='bilinear')
                    # 28x28 images are too small to inspect, so we upscale them to a larger size

                    for i in range(label.shape[0]):
                        vs_dir_i = os.path.join(vs_dir, f'{i}')
                        os.mkdir(vs_dir_i)
                        to_pil_img(img[i]).save(
                            os.path.join(vs_dir_i, f'input.png'))
                        for t in range(net.T):
                            print(f'saving {i}-th sample with t={t}...')
                            # spike_seq.shape = [T, N, C, H, W]

                            visualizing.plot_2d_feature_map(
                                spike_seq[t][i], 8, spike_seq.shape[2] // 8, 2,
                                f'$S[{t}]$')
                            plt.savefig(os.path.join(vs_dir_i, f's_{t}.png'),
                                        pad_inches=0.02)
                            plt.savefig(os.path.join(vs_dir_i, f's_{t}.pdf'),
                                        pad_inches=0.02)
                            plt.savefig(os.path.join(vs_dir_i, f's_{t}.svg'),
                                        pad_inches=0.02)
                            plt.clf()

                    exit()

    out_dir = os.path.join(
        args.out_dir,
        f'T{args.T}_b{args.b}_{args.opt}_lr{args.lr}_c{args.channels}')

    if args.amp:
        out_dir += '_amp'

    if args.cupy:
        out_dir += '_cupy'

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    writer = SummaryWriter(out_dir, purge_step=start_epoch)
    with open(os.path.join(out_dir, 'args.txt'), 'w',
              encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for img, label in train_data_loader:
            optimizer.zero_grad()
            img = img.to(args.device)
            label = label.to(args.device)
            label_onehot = F.one_hot(label, 10).float()

            if scaler is not None:
                with amp.autocast():
                    out_fr = net(img)
                    loss = F.mse_loss(out_fr, label_onehot)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                out_fr = net(img)
                loss = F.mse_loss(out_fr, label_onehot)
                loss.backward()
                optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            functional.reset_net(net)

        train_time = time.time()
        train_speed = train_samples / (train_time - start_time)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)
        lr_scheduler.step()

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for img, label in test_data_loader:
                img = img.to(args.device)
                label = label.to(args.device)
                label_onehot = F.one_hot(label, 10).float()
                out_fr = net(img)
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)
        test_time = time.time()
        test_speed = test_samples / (test_time - train_time)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(out_dir)
        print(
            f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}'
        )
        print(
            f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s'
        )
        print(
            f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n'
        )
Example No. 6
def main():
    # python -m spikingjelly.activation_based.examples.classify_dvsg -T 16 -device cuda:0 -b 16 -epochs 64 -data-dir /datasets/DVSGesture/ -amp -cupy -opt adam -lr 0.001 -j 8

    parser = argparse.ArgumentParser(description='Classify DVS Gesture')
    parser.add_argument('-T',
                        default=16,
                        type=int,
                        help='simulating time-steps')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=16, type=int, help='batch size')
    parser.add_argument('-epochs',
                        default=64,
                        type=int,
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-data-dir',
                        type=str,
                        help='root dir of DVS Gesture dataset')
    parser.add_argument('-out-dir',
                        type=str,
                        default='./logs',
                        help='root dir for saving logs and checkpoint')
    parser.add_argument('-resume',
                        type=str,
                        help='resume from the checkpoint path')
    parser.add_argument('-amp',
                        action='store_true',
                        help='automatic mixed precision training')
    parser.add_argument('-cupy', action='store_true', help='use cupy backend')
    parser.add_argument('-opt',
                        type=str,
                        help='use which optimizer. SGD or Adam')
    parser.add_argument('-momentum',
                        default=0.9,
                        type=float,
                        help='momentum for SGD')
    parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('-channels',
                        default=128,
                        type=int,
                        help='channels of CSNN')

    args = parser.parse_args()
    print(args)

    net = parametric_lif_net.DVSGestureNet(channels=args.channels,
                                           spiking_neuron=neuron.LIFNode,
                                           surrogate_function=surrogate.ATan(),
                                           detach_reset=True)

    functional.set_step_mode(net, 'm')
    if args.cupy:
        functional.set_backend(net, 'cupy', instance=neuron.LIFNode)

    print(net)

    net.to(args.device)

    train_set = DVS128Gesture(root=args.data_dir,
                              train=True,
                              data_type='frame',
                              frames_number=args.T,
                              split_by='number')
    test_set = DVS128Gesture(root=args.data_dir,
                             train=False,
                             data_type='frame',
                             frames_number=args.T,
                             split_by='number')

    train_data_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                    batch_size=args.b,
                                                    shuffle=True,
                                                    drop_last=True,
                                                    num_workers=args.j,
                                                    pin_memory=True)

    test_data_loader = torch.utils.data.DataLoader(dataset=test_set,
                                                   batch_size=args.b,
                                                   shuffle=True,
                                                   drop_last=False,
                                                   num_workers=args.j,
                                                   pin_memory=True)

    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    start_epoch = 0
    max_test_acc = -1

    optimizer = None
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(args.opt)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epochs)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    out_dir = os.path.join(
        args.out_dir,
        f'T{args.T}_b{args.b}_{args.opt}_lr{args.lr}_c{args.channels}')

    if args.amp:
        out_dir += '_amp'

    if args.cupy:
        out_dir += '_cupy'

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    writer = SummaryWriter(out_dir, purge_step=start_epoch)
    with open(os.path.join(out_dir, 'args.txt'), 'w',
              encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for frame, label in train_data_loader:
            optimizer.zero_grad()
            frame = frame.to(args.device)
            frame = frame.transpose(0, 1)  # [N, T, C, H, W] -> [T, N, C, H, W]
            label = label.to(args.device)
            label_onehot = F.one_hot(label, 11).float()

            if scaler is not None:
                with amp.autocast():
                    out_fr = net(frame).mean(0)
                    loss = F.mse_loss(out_fr, label_onehot)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                out_fr = net(frame).mean(0)
                loss = F.mse_loss(out_fr, label_onehot)
                loss.backward()
                optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            functional.reset_net(net)

        train_time = time.time()
        train_speed = train_samples / (train_time - start_time)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)
        lr_scheduler.step()

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for frame, label in test_data_loader:
                frame = frame.to(args.device)
                frame = frame.transpose(
                    0, 1)  # [N, T, C, H, W] -> [T, N, C, H, W]
                label = label.to(args.device)
                label_onehot = F.one_hot(label, 11).float()
                out_fr = net(frame).mean(0)
                loss = F.mse_loss(out_fr, label_onehot)
                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)
        test_time = time.time()
        test_speed = test_samples / (test_time - train_time)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(out_dir)
        print(
            f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}'
        )
        print(
            f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s'
        )
        print(
            f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n'
        )
Example No. 7
        for audios, labels in tqdm(train_dataloader):
            audios = audios.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()

            out_spikes_counter_frequency = net(audios)

            loss = criterion(out_spikes_counter_frequency, labels)
            loss.backward()

            # nn.utils.clip_grad_value_(net.parameters(), 5)

            optimizer.step()

            reset_net(net)

            # Rate-based output decoding
            correct_rate = (out_spikes_counter_frequency.argmax(
                dim=1) == labels).float().mean().item()

            net.train_times += 1

        if e >= warmup_epochs:
            lr_scheduler.step()

        net.eval()

        writer.add_scalar('Train Loss', loss.item(), global_step=net.epochs)

        ##### TEST #####
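The fragment stops at the test marker. A matching evaluation loop, reusing the rate-based decoding and `reset_net` pattern from the training loop above (a sketch assuming the same `net`, `device` and a `test_dataloader`), might look like:

    test_acc = 0
    test_samples = 0
    with torch.no_grad():
        for audios, labels in test_dataloader:
            audios = audios.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            out_spikes_counter_frequency = net(audios)
            test_acc += (out_spikes_counter_frequency.argmax(dim=1) == labels).float().sum().item()
            test_samples += labels.numel()
            reset_net(net)  # clear neuron states between batches
    print('test acc =', test_acc / test_samples)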
Example No. 8
def main():
    # python -m spikingjelly.activation_based.examples.rsnn_sequential_fmnist -device cuda:0 -b 256 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt adam -lr 0.001 -j 8 -model plain
    parser = argparse.ArgumentParser(description='Classify Sequential Fashion-MNIST')
    parser.add_argument('-model', default='plain', type=str, help='use which model, "plain", "ss" (StatefulSynapseNet) or "fb" (FeedBackNet)')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=128, type=int, help='batch size')
    parser.add_argument('-epochs', default=64, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-data-dir', type=str, help='root dir of Fashion-MNIST dataset')
    parser.add_argument('-out-dir', type=str, default='./logs', help='root dir for saving logs and checkpoint')
    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
    parser.add_argument('-amp', action='store_true', help='automatic mixed precision training')
    parser.add_argument('-cupy', action='store_true', help='use cupy backend')
    parser.add_argument('-opt', type=str, help='use which optimizer. SGD or Adam')
    parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
    parser.add_argument('-lr', default=0.1, type=float, help='learning rate')

    args = parser.parse_args()
    print(args)

    if args.model == 'plain':
        net = PlainNet()
    elif args.model == 'ss':
        net = StatefulSynapseNet()
    elif args.model == 'fb':
        net = FeedBackNet()
    else:
        raise NotImplementedError(args.model)

    net.to(args.device)

    # `functional.set_step_mode` will not set neurons in LinearRecurrentContainer to use step_mode = 'm'
    functional.set_step_mode(net, step_mode='m')

    if args.cupy:
        # neurons in LinearRecurrentContainer still use step_mode = 's', so they will still use backend = 'torch'
        functional.set_backend(net, backend='cupy')

    print(net)

    train_set = torchvision.datasets.FashionMNIST(
            root=args.data_dir,
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True)

    test_set = torchvision.datasets.FashionMNIST(
            root=args.data_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True)

    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.b,
        shuffle=True,
        drop_last=True,
        num_workers=args.j,
        pin_memory=True
    )

    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=args.b,
        shuffle=True,
        drop_last=False,
        num_workers=args.j,
        pin_memory=True
    )


    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    start_epoch = 0
    max_test_acc = -1

    optimizer = None
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(args.opt)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    out_dir = os.path.join(args.out_dir, f'{args.model}_b{args.b}_{args.opt}_lr{args.lr}')

    if args.amp:
        out_dir += '_amp'

    if args.cupy:
        out_dir += '_cupy'

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    writer = SummaryWriter(out_dir, purge_step=start_epoch)
    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for img, label in train_data_loader:
            optimizer.zero_grad()
            img = img.to(args.device)
            label = label.to(args.device)
            # img.shape = [N, 1, H, W]
            img.squeeze_(1)  # [N, H, W]
            img = img.permute(2, 0, 1)  # [W, N, H]
            # we regard [W, N, H] as [T, N, H]
            label_onehot = F.one_hot(label, 10).float()

            if scaler is not None:
                with amp.autocast():
                    out_fr = net(img)
                    loss = F.mse_loss(out_fr, label_onehot)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                out_fr = net(img)
                loss = F.mse_loss(out_fr, label_onehot)
                loss.backward()
                optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            functional.reset_net(net)

        train_time = time.time()
        train_speed = train_samples / (train_time - start_time)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)
        lr_scheduler.step()

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for img, label in test_data_loader:
                img = img.to(args.device)
                label = label.to(args.device)
                # img.shape = [N, 1, H, W]
                img.squeeze_(1)  # [N, H, W]
                img = img.permute(2, 0, 1)  # [W, N, H]
                # we regard [W, N, H] as [T, N, H]

                label_onehot = F.one_hot(label, 10).float()
                out_fr = net(img)
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)
        test_time = time.time()
        test_speed = test_samples / (test_time - train_time)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(out_dir)
        print(f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}')
        print(f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s')
        print(f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n')
Example No. 9
    state = envs.reset()

    writer = SummaryWriter(logdir='./log')

    while step_idx < max_steps:

        log_probs = []
        values = []
        rewards = []
        masks = []
        entropy = 0

        for _ in range(num_steps):
            state = torch.FloatTensor(state).to(device)
            dist, value = model(state)
            functional.reset_net(model)

            action = dist.sample()
            next_state, reward, done, _ = envs.step(action.cpu().numpy())

            log_prob = dist.log_prob(action)
            entropy += dist.entropy().mean()

            log_probs.append(log_prob)
            values.append(value)
            rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))
            masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))

            state = next_state
            step_idx += 1
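The rollout above stops before the update step. The usual continuation (not shown in this snippet) bootstraps the value of the last state and folds the done-masks into discounted returns; a sketch, with gamma as an assumed hyper-parameter:

    def compute_returns(next_value, rewards, masks, gamma=0.99):
        # discounted returns, cut off at episode boundaries by the masks
        R = next_value
        returns = []
        for step in reversed(range(len(rewards))):
            R = rewards[step] + gamma * R * masks[step]
            returns.insert(0, R)
        return returns

    next_state = torch.FloatTensor(next_state).to(device)
    _, next_value = model(next_state)
    functional.reset_net(model)
    returns = compute_returns(next_value, rewards, masks)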
Example No. 10
def main():
    '''
    :return: None

    A network with an FC-LIF structure for classifying MNIST.

    This function initializes the network, starts training, and reports the accuracy on the test dataset during training.
    '''
    parser = argparse.ArgumentParser(description='LIF MNIST Training')
    parser.add_argument('-T',
                        default=100,
                        type=int,
                        help='simulating time-steps')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=64, type=int, help='batch size')
    parser.add_argument('-epochs',
                        default=100,
                        type=int,
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j',
                        default=4,
                        type=int,
                        metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-data-dir',
                        type=str,
                        help='root dir of MNIST dataset')
    parser.add_argument('-out-dir',
                        type=str,
                        default='./logs',
                        help='root dir for saving logs and checkpoint')
    parser.add_argument('-resume',
                        type=str,
                        help='resume from the checkpoint path')
    parser.add_argument('-amp',
                        action='store_true',
                        help='automatic mixed precision training')
    parser.add_argument('-opt',
                        type=str,
                        choices=['sgd', 'adam'],
                        default='adam',
                        help='use which optimizer. SGD or Adam')
    parser.add_argument('-momentum',
                        default=0.9,
                        type=float,
                        help='momentum for SGD')
    parser.add_argument('-lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-tau',
                        default=2.0,
                        type=float,
                        help='parameter tau of LIF neuron')

    args = parser.parse_args()
    print(args)

    net = SNN(tau=args.tau)

    print(net)

    net.to(args.device)

    # initialize the data loaders
    train_dataset = torchvision.datasets.MNIST(
        root=args.data_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    test_dataset = torchvision.datasets.MNIST(
        root=args.data_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True)

    train_data_loader = data.DataLoader(dataset=train_dataset,
                                        batch_size=args.b,
                                        shuffle=True,
                                        drop_last=True,
                                        num_workers=args.j,
                                        pin_memory=True)
    test_data_loader = data.DataLoader(dataset=test_dataset,
                                       batch_size=args.b,
                                       shuffle=False,
                                       drop_last=False,
                                       num_workers=args.j,
                                       pin_memory=True)

    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    start_epoch = 0
    max_test_acc = -1

    optimizer = None
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(args.opt)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    out_dir = os.path.join(args.out_dir,
                           f'T{args.T}_b{args.b}_{args.opt}_lr{args.lr}')

    if args.amp:
        out_dir += '_amp'

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    writer = SummaryWriter(out_dir, purge_step=start_epoch)
    with open(os.path.join(out_dir, 'args.txt'), 'w',
              encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    encoder = encoding.PoissonEncoder()

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for img, label in train_data_loader:
            optimizer.zero_grad()
            img = img.to(args.device)
            label = label.to(args.device)
            label_onehot = F.one_hot(label, 10).float()

            if scaler is not None:
                with amp.autocast():
                    out_fr = 0.
                    for t in range(args.T):
                        encoded_img = encoder(img)
                        out_fr += net(encoded_img)
                    out_fr = out_fr / args.T
                    loss = F.mse_loss(out_fr, label_onehot)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                out_fr = 0.
                for t in range(args.T):
                    encoded_img = encoder(img)
                    out_fr += net(encoded_img)
                out_fr = out_fr / args.T
                loss = F.mse_loss(out_fr, label_onehot)
                loss.backward()
                optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            functional.reset_net(net)

        train_time = time.time()
        train_speed = train_samples / (train_time - start_time)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for img, label in test_data_loader:
                img = img.to(args.device)
                label = label.to(args.device)
                label_onehot = F.one_hot(label, 10).float()
                out_fr = 0.
                for t in range(args.T):
                    encoded_img = encoder(img)
                    out_fr += net(encoded_img)
                out_fr = out_fr / args.T
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)
        test_time = time.time()
        test_speed = test_samples / (test_time - train_time)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(out_dir)
        print(
            f'epoch ={epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}'
        )
        print(
            f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s'
        )
        print(
            f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n'
        )

    # save data for plotting
    net.eval()
    # register a forward hook to record voltages and spikes
    output_layer = net.layer[-1]  # the output layer
    output_layer.v_seq = []
    output_layer.s_seq = []

    def save_hook(m, x, y):
        m.v_seq.append(m.v.unsqueeze(0))
        m.s_seq.append(y.unsqueeze(0))

    output_layer.register_forward_hook(save_hook)

    with torch.no_grad():
        img, label = test_dataset[0]
        img = img.to(args.device)
        out_fr = 0.
        for t in range(args.T):
            encoded_img = encoder(img)
            out_fr += net(encoded_img)
        out_spikes_counter_frequency = (out_fr / args.T).cpu().numpy()
        print(f'Firing rate: {out_spikes_counter_frequency}')

        output_layer.v_seq = torch.cat(output_layer.v_seq)
        output_layer.s_seq = torch.cat(output_layer.s_seq)
        v_t_array = output_layer.v_seq.cpu().numpy().squeeze()  # v_t_array[i][j] is the voltage of neuron i at time step j
        np.save("v_t_array.npy", v_t_array)
        s_t_array = output_layer.s_seq.cpu().numpy().squeeze()  # s_t_array[i][j] is the spike (0 or 1) emitted by neuron i at time step j
        np.save("s_t_array.npy", s_t_array)
Example No. 11
def play(device, pt_path, hidden_num, played_frames=60, save_fig_num=0, fig_dir=None, figsize=(12, 6), firing_rates_plot_type='bar', heatmap_shape=None):
    import numpy as np
    from matplotlib import pyplot as plt
    import matplotlib.ticker
    plt.rcParams['figure.figsize'] = figsize
    # plt.rcParams['figure.dpi'] = 200
    plt.ion()
    env = gym.make('CartPole-v0').unwrapped

    policy_net = DQSN(hidden_num).to(device)
    policy_net.load_state_dict(torch.load(pt_path, map_location=device))

    env.reset()
    state = torch.zeros([1, 4], dtype=torch.float, device=device)

    with torch.no_grad():
        functional.set_monitor(policy_net, True)
        delta_lim = 0
        over_score = 1e9
        for i in count():
            # plt.clf()
            LIF_v = policy_net(state)  # shape=[1, 2]
            action = LIF_v.max(1)[1].view(1, 1).item()
            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (1, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (1, 0))
            plt.xticks(np.arange(2), ('Left', 'Right'))
            plt.ylabel('Voltage')
            plt.title('Voltage of LIF neurons at last time step')
            delta_lim = (LIF_v.max() - LIF_v.min()) * 0.5
            plt.ylim(LIF_v.min() - delta_lim, LIF_v.max() + delta_lim)
            plt.yticks([])
            plt.text(0, LIF_v[0][0], str(round(LIF_v[0][0].item(), 2)), ha='center')
            plt.text(1, LIF_v[0][1], str(round(LIF_v[0][1].item(), 2)), ha='center')

            plt.bar(np.arange(2), LIF_v.squeeze().cpu(), color=['r', 'gray'] if action == 0 else ['gray', 'r'], width=0.5)
            if LIF_v.min() - delta_lim < 0:
                plt.axhline(0, color='black', linewidth=0.1)

            IF_spikes = np.asarray(policy_net.fc[1].monitor['s'])  # shape=[16, 1, 256]
            firing_rates = IF_spikes.mean(axis=0).squeeze()
            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 4), rowspan=2, colspan=5)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 1), rowspan=2, colspan=2)
            plt.title('Firing rates of IF neurons')

            if firing_rates_plot_type == 'bar':
                # draw a bar chart of the firing rates
                plt.xlabel('Neuron index')
                plt.ylabel('Firing rate')
                plt.xlim(0, firing_rates.size)
                plt.ylim(0, 1.01)
                plt.bar(np.arange(firing_rates.size), firing_rates, width=0.5)

            elif firing_rates_plot_type == 'heatmap':
                # draw a heatmap of the firing rates
                heatmap = plt.imshow(firing_rates.reshape(heatmap_shape), vmin=0, vmax=1, cmap='ocean')
                plt.gca().invert_yaxis()
                cbar = heatmap.figure.colorbar(heatmap)
                cbar.ax.set_ylabel('Magnitude', rotation=90, va='top')
            functional.reset_net(policy_net)
            subtitle = f'Position={state[0][0].item(): .2f}, Velocity={state[0][1].item(): .2f}, Pole Angle={state[0][2].item(): .2f}, Pole Velocity At Tip={state[0][3].item(): .2f}, Score={i}'

            state, reward, done, _ = env.step(action)
            if done:
                over_score = min(over_score, i)
                subtitle = f'Game over, Score={over_score}'
            plt.suptitle(subtitle)

            state = torch.from_numpy(state).float().to(device).unsqueeze(0)
            screen = env.render(mode='rgb_array').copy()
            screen[300, :, :] = 0  # draw a black line across the screen
            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 0))
            plt.xticks([])
            plt.yticks([])
            plt.title('Game screen')
            plt.imshow(screen, interpolation='bicubic')
            plt.pause(0.001)
            if i < save_fig_num:
                plt.savefig(os.path.join(fig_dir, f'{i}.png'))
            if done and i >= played_frames:
                env.close()
                plt.close()
                break
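A possible invocation of `play` (the checkpoint path is hypothetical; `hidden_num=256` matches the monitor shape `[16, 1, 256]` noted in the code):

    play(device='cpu',
         pt_path='./policy_net_256.pt',  # hypothetical checkpoint path
         hidden_num=256,
         played_frames=60,
         firing_rates_plot_type='bar')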