def ppo_update(ppo_epochs, mini_batch_size, states, actions, log_probs, returns, advantages, clip_param=0.2):
    # Run `ppo_epochs` passes of clipped-surrogate PPO over mini-batches drawn
    # from the collected rollout. Uses module-level `model` (actor-critic SNN),
    # `optimizer`, `ppo_iter` and spikingjelly's `functional`.
    for _ in range(ppo_epochs):
        for state, action, old_log_probs, return_, advantage in ppo_iter(
                mini_batch_size, states, actions, log_probs, returns, advantages):
            dist, value = model(state)
            # SNN membrane state must be cleared after every forward pass.
            functional.reset_net(model)
            entropy = dist.entropy().mean()
            new_log_probs = dist.log_prob(action)
            # Importance ratio pi_new / pi_old, computed in log space for stability.
            ratio = (new_log_probs - old_log_probs).exp()
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage
            # Clipped surrogate objective (PPO); we minimize its negative.
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = (return_ - value).pow(2).mean()
            # Entropy bonus (coefficient 0.001) encourages exploration.
            loss = 0.5 * critic_loss + actor_loss - 0.001 * entropy
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
def optimize_model():
    # One DQN optimization step over a mini-batch sampled from replay memory.
    # Uses module-level `memory`, `Transition`, `policy_net`, `target_net`,
    # `optimizer`, `BATCH_SIZE`, `GAMMA` and `device`.
    if len(memory) < BATCH_SIZE:
        # Not enough transitions collected yet.
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of batched fields.
    batch = Transition(*zip(*transitions))
    # Mask of entries whose episode did NOT terminate at this step
    # (terminal transitions store next_state=None).
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    # Q(s, a) of the actions that were actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    # max_a' Q_target(s', a'); terminal states keep a value of 0.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    functional.reset_net(target_net)
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    # Element-wise gradient clipping to [-1, 1] for training stability.
    for param in policy_net.parameters():
        if param.grad is not None:
            param.grad.data.clamp_(-1, 1)
    optimizer.step()
    # Clear the SNN state of the policy net after its forward/backward pass.
    functional.reset_net(policy_net)
def select_action(state, steps_done, n_actions=2):
    """Epsilon-greedy action selection for the spiking DQN.

    The exploration rate decays exponentially from ``EPS_START`` to
    ``EPS_END`` with a time constant of ``EPS_DECAY`` environment steps.

    :param state: current observation tensor, shape ``[1, state_dim]``
    :param steps_done: number of environment steps taken so far
    :param n_actions: size of the discrete action space (default 2,
        matching the original hard-coded CartPole setup; generalized so the
        same helper works for other environments)
    :return: chosen action as a ``[1, 1]`` int64 tensor on ``device``
    """
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    if sample > eps_threshold:
        # Exploit: greedy action from the policy network (no grad needed).
        with torch.no_grad():
            ac = policy_net(state).max(1)[1].view(1, 1)
            # Clear the SNN state after the forward pass.
            functional.reset_net(policy_net)
            return ac
    else:
        # Explore: uniformly random action.
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
def test_env(vis=False):
    """Roll out one evaluation episode with the current policy.

    :param vis: render every frame when True
    :return: cumulative episode reward
    """
    obs = env.reset()
    if vis:
        env.render()
    episode_return = 0
    finished = False
    while not finished:
        obs_tensor = torch.FloatTensor(obs).unsqueeze(0).to(device)
        dist, _ = model(obs_tensor)
        # Clear SNN membrane state after each forward pass.
        functional.reset_net(model)
        chosen = dist.sample().cpu().numpy()[0]
        obs, reward, finished, _ = env.step(chosen)
        if vis:
            env.render()
        episode_return += reward
    return episode_return
def main():
    '''
    Train a convolutional SNN (CSNN) to classify Fashion-MNIST.

    Command-line help::

        (sj-dev) wfang@Precision-5820-Tower-X-Series:~/spikingjelly_dev$ python -m spikingjelly.activation_based.examples.conv_fashion_mnist -h
        usage: conv_fashion_mnist.py [-h] [-T T] [-device DEVICE] [-b B] [-epochs N] [-j N] [-data-dir DATA_DIR] [-out-dir OUT_DIR] [-resume RESUME] [-amp] [-cupy] [-opt OPT] [-momentum MOMENTUM] [-lr LR]

        Classify Fashion-MNIST

        optional arguments:
          -h, --help          show this help message and exit
          -T T                simulating time-steps
          -device DEVICE      device
          -b B                batch size
          -epochs N           number of total epochs to run
          -j N                number of data loading workers (default: 4)
          -data-dir DATA_DIR  root dir of Fashion-MNIST dataset
          -out-dir OUT_DIR    root dir for saving logs and checkpoint
          -resume RESUME      resume from the checkpoint path
          -amp                automatic mixed precision training
          -cupy               use cupy neuron and multi-step forward mode
          -opt OPT            use which optimizer. SDG or Adam
          -momentum MOMENTUM  momentum for SGD
          -save-es            dir for saving a batch spikes encoded by the first {Conv2d-BatchNorm2d-IFNode}
    '''
    # Example runs:
    # python -m spikingjelly.activation_based.examples.conv_fashion_mnist -T 4 -device cuda:0 -b 128 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8
    # python -m spikingjelly.activation_based.examples.conv_fashion_mnist -T 4 -device cuda:0 -b 4 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8 -resume ./logs/T4_b256_sgd_lr0.1_c128_amp_cupy/checkpoint_latest.pth -save-es ./logs
    parser = argparse.ArgumentParser(description='Classify Fashion-MNIST')
    parser.add_argument('-T', default=4, type=int, help='simulating time-steps')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=128, type=int, help='batch size')
    parser.add_argument('-epochs', default=64, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-j', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('-data-dir', type=str, help='root dir of Fashion-MNIST dataset')
    parser.add_argument('-out-dir', type=str, default='./logs', help='root dir for saving logs and checkpoint')
    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
    parser.add_argument('-amp', action='store_true', help='automatic mixed precision training')
    parser.add_argument('-cupy', action='store_true', help='use cupy backend')
    parser.add_argument('-opt', type=str, help='use which optimizer. SDG or Adam')
    parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
    parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('-channels', default=128, type=int, help='channels of CSNN')
    parser.add_argument(
        '-save-es', default=None,
        help='dir for saving a batch spikes encoded by the first {Conv2d-BatchNorm2d-IFNode}'
    )

    args = parser.parse_args()
    print(args)

    net = CSNN(T=args.T, channels=args.channels, use_cupy=args.cupy)
    print(net)
    net.to(args.device)

    # Fashion-MNIST datasets and loaders.
    train_set = torchvision.datasets.FashionMNIST(
        root=args.data_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    test_set = torchvision.datasets.FashionMNIST(
        root=args.data_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True)

    train_data_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=args.b, shuffle=True, drop_last=True, num_workers=args.j, pin_memory=True)
    test_data_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=args.b, shuffle=True, drop_last=False, num_workers=args.j, pin_memory=True)

    # Gradient scaler is only needed for automatic mixed precision.
    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    start_epoch = 0
    max_test_acc = -1

    optimizer = None
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(args.opt)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    # Resume training state from a checkpoint if requested.
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    # Visualization-only mode: dump the spikes produced by the first
    # {Conv2d-BatchNorm2d-IFNode} encoder for one test batch, then exit.
    if args.save_es is not None and args.save_es != '':
        encoder = net.spiking_encoder()
        with torch.no_grad():
            for img, label in test_data_loader:
                img = img.to(args.device)
                label = label.to(args.device)
                # img.shape = [N, C, H, W]
                img_seq = img.unsqueeze(0).repeat(net.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]
                spike_seq = encoder(img_seq)
                functional.reset_net(encoder)
                to_pil_img = torchvision.transforms.ToPILImage()
                vs_dir = os.path.join(args.save_es, 'visualization')
                # NOTE(review): os.mkdir raises if the directory already exists
                # (e.g. on a second run) — confirm whether that is intended.
                os.mkdir(vs_dir)
                img = img.cpu()
                spike_seq = spike_seq.cpu()
                img = F.interpolate(img, scale_factor=4, mode='bilinear')  # 28 * 28 is too small to read. So, we interpolate it to a larger size
                for i in range(label.shape[0]):
                    vs_dir_i = os.path.join(vs_dir, f'{i}')
                    os.mkdir(vs_dir_i)
                    to_pil_img(img[i]).save(os.path.join(vs_dir_i, f'input.png'))
                    for t in range(net.T):
                        print(f'saving {i}-th sample with t={t}...')
                        # spike_seq.shape = [T, N, C, H, W]
                        visualizing.plot_2d_feature_map(spike_seq[t][i], 8, spike_seq.shape[2] // 8, 2, f'$S[{t}]$')
                        plt.savefig(os.path.join(vs_dir_i, f's_{t}.png'), pad_inches=0.02)
                        plt.savefig(os.path.join(vs_dir_i, f's_{t}.pdf'), pad_inches=0.02)
                        plt.savefig(os.path.join(vs_dir_i, f's_{t}.svg'), pad_inches=0.02)
                        plt.clf()
                # Only the first batch is saved; stop the whole program here.
                exit()

    # Output directory encodes the hyper-parameters for easy comparison.
    out_dir = os.path.join(args.out_dir, f'T{args.T}_b{args.b}_{args.opt}_lr{args.lr}_c{args.channels}')

    if args.amp:
        out_dir += '_amp'

    if args.cupy:
        out_dir += '_cupy'

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    writer = SummaryWriter(out_dir, purge_step=start_epoch)
    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for img, label in train_data_loader:
            optimizer.zero_grad()
            img = img.to(args.device)
            label = label.to(args.device)
            # MSE against one-hot targets trains the output firing rate.
            label_onehot = F.one_hot(label, 10).float()

            if scaler is not None:
                # AMP path: forward under autocast, scaled backward/step.
                with amp.autocast():
                    out_fr = net(img)
                    loss = F.mse_loss(out_fr, label_onehot)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                out_fr = net(img)
                loss = F.mse_loss(out_fr, label_onehot)
                loss.backward()
                optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            # Reset the SNN state before the next mini-batch.
            functional.reset_net(net)

        train_time = time.time()
        train_speed = train_samples / (train_time - start_time)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)
        lr_scheduler.step()

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for img, label in test_data_loader:
                img = img.to(args.device)
                label = label.to(args.device)
                label_onehot = F.one_hot(label, 10).float()
                out_fr = net(img)
                loss = F.mse_loss(out_fr, label_onehot)
                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)
        test_time = time.time()
        test_speed = test_samples / (test_time - train_time)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        # Save the latest checkpoint every epoch, plus a copy whenever the
        # best test accuracy improves.
        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(out_dir)
        print(
            f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}'
        )
        print(
            f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s'
        )
        print(
            f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n'
        )
def main():
    '''
    Train DVSGestureNet (a spiking CNN with LIF neurons) to classify the
    DVS128 Gesture event-camera dataset, using frame-based representation
    with ``T`` frames per sample.
    '''
    # python -m spikingjelly.activation_based.examples.classify_dvsg -T 16 -device cuda:0 -b 16 -epochs 64 -data-dir /datasets/DVSGesture/ -amp -cupy -opt adam -lr 0.001 -j 8
    parser = argparse.ArgumentParser(description='Classify DVS Gesture')
    parser.add_argument('-T', default=16, type=int, help='simulating time-steps')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=16, type=int, help='batch size')
    parser.add_argument('-epochs', default=64, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-j', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('-data-dir', type=str, help='root dir of DVS Gesture dataset')
    parser.add_argument('-out-dir', type=str, default='./logs', help='root dir for saving logs and checkpoint')
    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
    parser.add_argument('-amp', action='store_true', help='automatic mixed precision training')
    parser.add_argument('-cupy', action='store_true', help='use cupy backend')
    parser.add_argument('-opt', type=str, help='use which optimizer. SDG or Adam')
    parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
    parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('-channels', default=128, type=int, help='channels of CSNN')

    args = parser.parse_args()
    print(args)

    net = parametric_lif_net.DVSGestureNet(channels=args.channels, spiking_neuron=neuron.LIFNode, surrogate_function=surrogate.ATan(), detach_reset=True)

    # Multi-step mode: each forward call consumes the whole [T, N, ...] sequence.
    functional.set_step_mode(net, 'm')
    if args.cupy:
        functional.set_backend(net, 'cupy', instance=neuron.LIFNode)

    print(net)

    net.to(args.device)

    # DVS128 Gesture as fixed-length frame sequences (T frames per sample).
    train_set = DVS128Gesture(root=args.data_dir, train=True, data_type='frame', frames_number=args.T, split_by='number')
    test_set = DVS128Gesture(root=args.data_dir, train=False, data_type='frame', frames_number=args.T, split_by='number')

    train_data_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=args.b, shuffle=True, drop_last=True, num_workers=args.j, pin_memory=True)
    test_data_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=args.b, shuffle=True, drop_last=False, num_workers=args.j, pin_memory=True)

    # Gradient scaler only when automatic mixed precision is enabled.
    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    start_epoch = 0
    max_test_acc = -1

    optimizer = None
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(args.opt)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    # Resume training state from a checkpoint if requested.
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    # Output directory encodes the hyper-parameters.
    out_dir = os.path.join(args.out_dir, f'T{args.T}_b{args.b}_{args.opt}_lr{args.lr}_c{args.channels}')

    if args.amp:
        out_dir += '_amp'

    if args.cupy:
        out_dir += '_cupy'

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    writer = SummaryWriter(out_dir, purge_step=start_epoch)
    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for frame, label in train_data_loader:
            optimizer.zero_grad()
            frame = frame.to(args.device)
            frame = frame.transpose(0, 1)  # [N, T, C, H, W] -> [T, N, C, H, W]
            label = label.to(args.device)
            # 11 gesture classes; MSE against one-hot firing-rate targets.
            label_onehot = F.one_hot(label, 11).float()

            if scaler is not None:
                # AMP path: forward under autocast, scaled backward/step.
                with amp.autocast():
                    # Average the output over the time dimension -> firing rate.
                    out_fr = net(frame).mean(0)
                    loss = F.mse_loss(out_fr, label_onehot)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                out_fr = net(frame).mean(0)
                loss = F.mse_loss(out_fr, label_onehot)
                loss.backward()
                optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            # Reset the SNN state before the next mini-batch.
            functional.reset_net(net)

        train_time = time.time()
        train_speed = train_samples / (train_time - start_time)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)
        lr_scheduler.step()

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for frame, label in test_data_loader:
                frame = frame.to(args.device)
                frame = frame.transpose(0, 1)  # [N, T, C, H, W] -> [T, N, C, H, W]
                label = label.to(args.device)
                label_onehot = F.one_hot(label, 11).float()
                out_fr = net(frame).mean(0)
                loss = F.mse_loss(out_fr, label_onehot)
                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)
        test_time = time.time()
        test_speed = test_samples / (test_time - train_time)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        # Save the latest checkpoint every epoch, plus a copy on a new best.
        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(out_dir)
        print(
            f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}'
        )
        print(
            f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s'
        )
        print(
            f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n'
        )
# One training pass over the audio dataset for the current epoch.
# NOTE(review): this is a fragment of a larger epoch loop — `e`,
# `warmup_epochs`, `net`, `optimizer`, `criterion`, `lr_scheduler`,
# `writer` and `device` are defined outside this view.
for audios, labels in tqdm(train_dataloader):
    audios = audios.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    optimizer.zero_grad()
    out_spikes_counter_frequency = net(audios)
    loss = criterion(out_spikes_counter_frequency, labels)
    loss.backward()
    # nn.utils.clip_grad_value_(net.parameters(), 5)
    optimizer.step()
    # Clear the SNN state before the next mini-batch.
    reset_net(net)
    # Rate-based output decoding
    correct_rate = (out_spikes_counter_frequency.argmax(dim=1) == labels).float().mean().item()
    net.train_times += 1
# Only step the LR schedule after the warmup period has passed.
if e >= warmup_epochs:
    lr_scheduler.step()
net.eval()
# NOTE(review): `loss` here is the loss of the LAST mini-batch only.
writer.add_scalar('Train Loss', loss.item(), global_step=net.epochs)
##### TEST #####
def main():
    """Train a recurrent SNN ('plain', 'ss' or 'fb') on Sequential Fashion-MNIST.

    Each 28x28 image is fed column by column: the width dimension is treated
    as time, so the network sees a length-28 sequence of 28-dimensional inputs.
    """
    # python -m spikingjelly.activation_based.examples.rsnn_sequential_fmnist -device cuda:0 -b 256 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt adam -lr 0.001 -j 8 -model plain
    parser = argparse.ArgumentParser(description='Classify Sequential Fashion-MNIST')
    parser.add_argument('-model', default='plain', type=str, help='use which model, "plain", "ss" (StatefulSynapseNet) or "fb" (FeedBackNet)')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=128, type=int, help='batch size')
    parser.add_argument('-epochs', default=64, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-j', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('-data-dir', type=str, help='root dir of Fashion-MNIST dataset')
    parser.add_argument('-out-dir', type=str, default='./logs', help='root dir for saving logs and checkpoint')
    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
    parser.add_argument('-amp', action='store_true', help='automatic mixed precision training')
    parser.add_argument('-cupy', action='store_true', help='use cupy backend')
    parser.add_argument('-opt', type=str, help='use which optimizer. SDG or Adam')
    parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
    parser.add_argument('-lr', default=0.1, type=float, help='learning rate')

    args = parser.parse_args()
    print(args)

    if args.model == 'plain':
        net = PlainNet()
    elif args.model == 'ss':
        net = StatefulSynapseNet()
    elif args.model == 'fb':
        net = FeedBackNet()
    else:
        # Fail fast on an unknown model name instead of hitting a NameError
        # when `net` is first used (consistent with the optimizer handling).
        raise NotImplementedError(args.model)

    net.to(args.device)

    # `functional.set_step_mode` will not set neurons in LinearRecurrentContainer to use step_mode = 'm'
    functional.set_step_mode(net, step_mode='m')
    if args.cupy:
        # neurons in LinearRecurrentContainer still use step_mode = 's', so, they will still use backend = 'torch'
        functional.set_backend(net, backend='cupy')

    print(net)

    # Fashion-MNIST datasets and loaders.
    train_set = torchvision.datasets.FashionMNIST(
        root=args.data_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    test_set = torchvision.datasets.FashionMNIST(
        root=args.data_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True)

    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_size=args.b, shuffle=True,
        drop_last=True, num_workers=args.j, pin_memory=True
    )
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=args.b, shuffle=True,
        drop_last=False, num_workers=args.j, pin_memory=True
    )

    # Gradient scaler only when automatic mixed precision is enabled.
    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    start_epoch = 0
    max_test_acc = -1

    optimizer = None
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(args.opt)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    # Resume training state from a checkpoint if requested.
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    # Output directory encodes the hyper-parameters.
    out_dir = os.path.join(args.out_dir, f'{args.model}_b{args.b}_{args.opt}_lr{args.lr}')

    if args.amp:
        out_dir += '_amp'

    if args.cupy:
        out_dir += '_cupy'

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    writer = SummaryWriter(out_dir, purge_step=start_epoch)
    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for img, label in train_data_loader:
            optimizer.zero_grad()
            img = img.to(args.device)
            label = label.to(args.device)
            # img.shape = [N, 1, H, W]
            img.squeeze_(1)  # [N, H, W]
            img = img.permute(2, 0, 1)  # [W, N, H]
            # we regard [W, N, H] as [T, N, H]
            label_onehot = F.one_hot(label, 10).float()

            if scaler is not None:
                # AMP path: forward under autocast, scaled backward/step.
                with amp.autocast():
                    out_fr = net(img)
                    loss = F.mse_loss(out_fr, label_onehot)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                out_fr = net(img)
                loss = F.mse_loss(out_fr, label_onehot)
                loss.backward()
                optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            # Reset the SNN state before the next mini-batch.
            functional.reset_net(net)

        train_time = time.time()
        train_speed = train_samples / (train_time - start_time)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)
        lr_scheduler.step()

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for img, label in test_data_loader:
                img = img.to(args.device)
                label = label.to(args.device)
                # img.shape = [N, 1, H, W]
                img.squeeze_(1)  # [N, H, W]
                img = img.permute(2, 0, 1)  # [W, N, H]
                # we regard [W, N, H] as [T, N, H]
                label_onehot = F.one_hot(label, 10).float()
                out_fr = net(img)
                loss = F.mse_loss(out_fr, label_onehot)
                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)
        test_time = time.time()
        test_speed = test_samples / (test_time - train_time)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        # Save the latest checkpoint every epoch, plus a copy on a new best.
        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(out_dir)
        print(f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}')
        print(f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s')
        print(f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n')
# A2C-style rollout collection over vectorized environments.
# NOTE(review): fragment of a larger script — `envs`, `model`, `device`,
# `step_idx`, `max_steps` and `num_steps` are defined outside this view,
# and the advantage/return computation that consumes these buffers
# presumably follows after the inner loop.
state = envs.reset()

writer = SummaryWriter(logdir='./log')
while step_idx < max_steps:
    # Per-rollout buffers, one entry per collected step.
    log_probs = []
    values = []
    rewards = []
    masks = []
    entropy = 0

    for _ in range(num_steps):
        state = torch.FloatTensor(state).to(device)
        dist, value = model(state)
        # Clear the SNN state after each forward pass.
        functional.reset_net(model)
        action = dist.sample()
        next_state, reward, done, _ = envs.step(action.cpu().numpy())

        log_prob = dist.log_prob(action)
        entropy += dist.entropy().mean()

        log_probs.append(log_prob)
        values.append(value)
        rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))
        # mask = 0 where the episode ended, so returns do not bootstrap
        # across episode boundaries.
        masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))

        state = next_state
        step_idx += 1
def main():
    '''
    :return: None

    * :ref:`API in English <lif_fc_mnist.main-en>`

    .. _lif_fc_mnist.main-cn:

    Train a fully-connected LIF network to classify MNIST.

    This function initializes the network, trains it, and prints the test
    accuracy during training.

    * :ref:`中文API <lif_fc_mnist.main-cn>`

    .. _lif_fc_mnist.main-en:

    The network with FC-LIF structure for classifying MNIST.

    This function initials the network, starts training and shows accuracy on test dataset.
    '''
    parser = argparse.ArgumentParser(description='LIF MNIST Training')
    parser.add_argument('-T', default=100, type=int, help='simulating time-steps')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=64, type=int, help='batch size')
    parser.add_argument('-epochs', default=100, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-j', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
    parser.add_argument('-data-dir', type=str, help='root dir of MNIST dataset')
    parser.add_argument('-out-dir', type=str, default='./logs', help='root dir for saving logs and checkpoint')
    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
    parser.add_argument('-amp', action='store_true', help='automatic mixed precision training')
    parser.add_argument('-opt', type=str, choices=['sgd', 'adam'], default='adam', help='use which optimizer. SGD or Adam')
    parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
    parser.add_argument('-lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-tau', default=2.0, type=float, help='parameter tau of LIF neuron')

    args = parser.parse_args()
    print(args)

    net = SNN(tau=args.tau)
    print(net)
    net.to(args.device)

    # Initialize the data loaders.
    train_dataset = torchvision.datasets.MNIST(
        root=args.data_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True
    )
    test_dataset = torchvision.datasets.MNIST(
        root=args.data_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True
    )

    train_data_loader = data.DataLoader(dataset=train_dataset, batch_size=args.b, shuffle=True, drop_last=True, num_workers=args.j, pin_memory=True)
    test_data_loader = data.DataLoader(dataset=test_dataset, batch_size=args.b, shuffle=False, drop_last=False, num_workers=args.j, pin_memory=True)

    # Gradient scaler only when automatic mixed precision is enabled.
    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    start_epoch = 0
    max_test_acc = -1

    optimizer = None
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(args.opt)

    # Resume training state from a checkpoint if requested.
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    out_dir = os.path.join(args.out_dir, f'T{args.T}_b{args.b}_{args.opt}_lr{args.lr}')

    if args.amp:
        out_dir += '_amp'

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    writer = SummaryWriter(out_dir, purge_step=start_epoch)
    # Fix: args.txt was previously opened and written twice in a row; the
    # first write was redundant and immediately overwritten by this one.
    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    # Poisson encoding turns pixel intensities into spike probabilities.
    encoder = encoding.PoissonEncoder()

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for img, label in train_data_loader:
            optimizer.zero_grad()
            img = img.to(args.device)
            label = label.to(args.device)
            label_onehot = F.one_hot(label, 10).float()

            if scaler is not None:
                # AMP path: accumulate the output firing rate over T steps
                # under autocast, then scaled backward/step.
                with amp.autocast():
                    out_fr = 0.
                    for t in range(args.T):
                        encoded_img = encoder(img)
                        out_fr += net(encoded_img)
                    out_fr = out_fr / args.T
                    loss = F.mse_loss(out_fr, label_onehot)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                # Average output spikes over T time-steps -> firing rate.
                out_fr = 0.
                for t in range(args.T):
                    encoded_img = encoder(img)
                    out_fr += net(encoded_img)
                out_fr = out_fr / args.T
                loss = F.mse_loss(out_fr, label_onehot)
                loss.backward()
                optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            # Reset the SNN state before the next mini-batch.
            functional.reset_net(net)

        train_time = time.time()
        train_speed = train_samples / (train_time - start_time)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for img, label in test_data_loader:
                img = img.to(args.device)
                label = label.to(args.device)
                label_onehot = F.one_hot(label, 10).float()
                out_fr = 0.
                for t in range(args.T):
                    encoded_img = encoder(img)
                    out_fr += net(encoded_img)
                out_fr = out_fr / args.T
                loss = F.mse_loss(out_fr, label_onehot)
                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)
        test_time = time.time()
        test_speed = test_samples / (test_time - train_time)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        # Save the latest checkpoint every epoch, plus a copy on a new best.
        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(out_dir)
        print(
            f'epoch ={epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}'
        )
        print(
            f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s'
        )
        print(
            f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n'
        )

    # Save data for plotting the output layer's dynamics.
    net.eval()
    # Register a forward hook that records the output layer's membrane
    # voltage and spikes at every time-step.
    output_layer = net.layer[-1]  # the output layer
    output_layer.v_seq = []
    output_layer.s_seq = []

    def save_hook(m, x, y):
        m.v_seq.append(m.v.unsqueeze(0))
        m.s_seq.append(y.unsqueeze(0))

    output_layer.register_forward_hook(save_hook)

    with torch.no_grad():
        img, label = test_dataset[0]
        img = img.to(args.device)
        out_fr = 0.
        for t in range(args.T):
            encoded_img = encoder(img)
            out_fr += net(encoded_img)
        out_spikes_counter_frequency = (out_fr / args.T).cpu().numpy()
        print(f'Firing rate: {out_spikes_counter_frequency}')

        output_layer.v_seq = torch.cat(output_layer.v_seq)
        output_layer.s_seq = torch.cat(output_layer.s_seq)
        # v_t_array[i][j] is the voltage of neuron i at time-step j.
        v_t_array = output_layer.v_seq.cpu().numpy().squeeze()
        np.save("v_t_array.npy", v_t_array)
        # s_t_array[i][j] is the spike (0 or 1) of neuron i at time-step j.
        s_t_array = output_layer.s_seq.cpu().numpy().squeeze()
        np.save("s_t_array.npy", s_t_array)
def play(device, pt_path, hidden_num, played_frames=60, save_fig_num=0, fig_dir=None, figsize=(12, 6), firing_rates_plot_type='bar', heatmap_shape=None):
    """Load a trained spiking DQN and play CartPole while visualizing the
    output-layer voltages, the hidden IF neurons' firing rates and the game
    screen with matplotlib.

    :param device: torch device the network runs on
    :param pt_path: path of the saved ``state_dict`` to load
    :param hidden_num: hidden layer size of the DQSN
    :param played_frames: stop after this many frames once an episode ends
    :param save_fig_num: save the first N frames as PNG files into fig_dir
    :param fig_dir: directory for saved figures (used when save_fig_num > 0)
    :param figsize: matplotlib figure size
    :param firing_rates_plot_type: 'bar' or 'heatmap'
    :param heatmap_shape: 2D shape to reshape firing rates to (heatmap mode)
    """
    import numpy as np
    from matplotlib import pyplot as plt
    import matplotlib.ticker
    plt.rcParams['figure.figsize'] = figsize
    # plt.rcParams['figure.dpi'] = 200
    plt.ion()
    env = gym.make('CartPole-v0').unwrapped

    policy_net = DQSN(hidden_num).to(device)
    policy_net.load_state_dict(torch.load(pt_path, map_location=device))

    env.reset()
    state = torch.zeros([1, 4], dtype=torch.float, device=device)

    with torch.no_grad():
        # Enable spike monitoring so hidden-layer firing rates can be read.
        functional.set_monitor(policy_net, True)
        delta_lim = 0
        over_score = 1e9
        for i in count():
            # plt.clf()
            LIF_v = policy_net(state)  # shape=[1, 2]
            action = LIF_v.max(1)[1].view(1, 1).item()

            # --- subplot: output-layer voltage at the last time step ---
            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (1, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (1, 0))

            plt.xticks(np.arange(2), ('Left', 'Right'))
            plt.ylabel('Voltage')
            plt.title('Voltage of LIF neurons at last time step')
            delta_lim = (LIF_v.max() - LIF_v.min()) * 0.5
            plt.ylim(LIF_v.min() - delta_lim, LIF_v.max() + delta_lim)
            plt.yticks([])
            plt.text(0, LIF_v[0][0], str(round(LIF_v[0][0].item(), 2)), ha='center')
            plt.text(1, LIF_v[0][1], str(round(LIF_v[0][1].item(), 2)), ha='center')
            # Highlight the chosen action in red.
            plt.bar(np.arange(2), LIF_v.squeeze(), color=['r', 'gray'] if action == 0 else ['gray', 'r'], width=0.5)
            if LIF_v.min() - delta_lim < 0:
                plt.axhline(0, color='black', linewidth=0.1)

            # --- subplot: hidden IF neurons' firing rates ---
            IF_spikes = np.asarray(policy_net.fc[1].monitor['s'])  # shape=[16, 1, 256]
            firing_rates = IF_spikes.mean(axis=0).squeeze()
            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 4), rowspan=2, colspan=5)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 1), rowspan=2, colspan=2)

            plt.title('Firing rates of IF neurons')

            if firing_rates_plot_type == 'bar':
                # Draw a bar chart of per-neuron firing rates.
                plt.xlabel('Neuron index')
                plt.ylabel('Firing rate')
                plt.xlim(0, firing_rates.size)
                plt.ylim(0, 1.01)
                plt.bar(np.arange(firing_rates.size), firing_rates, width=0.5)
            elif firing_rates_plot_type == 'heatmap':
                # Draw a heatmap of firing rates.
                heatmap = plt.imshow(firing_rates.reshape(heatmap_shape), vmin=0, vmax=1, cmap='ocean')
                plt.gca().invert_yaxis()
                cbar = heatmap.figure.colorbar(heatmap)
                cbar.ax.set_ylabel('Magnitude', rotation=90, va='top')

            # Clear the SNN state before the next forward pass.
            functional.reset_net(policy_net)
            # Subtitle reflects the state BEFORE the step; overwritten on game over.
            subtitle = f'Position={state[0][0].item(): .2f}, Velocity={state[0][1].item(): .2f}, Pole Angle={state[0][2].item(): .2f}, Pole Velocity At Tip={state[0][3].item(): .2f}, Score={i}'
            state, reward, done, _ = env.step(action)

            if done:
                over_score = min(over_score, i)
                subtitle = f'Game over, Score={over_score}'
            plt.suptitle(subtitle)

            state = torch.from_numpy(state).float().to(device).unsqueeze(0)

            # --- subplot: rendered game screen ---
            screen = env.render(mode='rgb_array').copy()
            screen[300, :, :] = 0  # draw a black separator line
            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 0))

            plt.xticks([])
            plt.yticks([])
            plt.title('Game screen')
            plt.imshow(screen, interpolation='bicubic')
            plt.pause(0.001)

            if i < save_fig_num:
                plt.savefig(os.path.join(fig_dir, f'{i}.png'))

            if done and i >= played_frames:
                env.close()
                plt.close()
                break