def main(log_dir=None):
    """Convert a pretrained ImageNet ResNet-50 to an SNN and simulate it.

    The ANN accuracy is validated first, then the network is converted and
    simulated through ``utils.onnx_ann2snn`` using the default ann2snn
    configuration. The settings below (device, dataset location, batch size,
    simulation length ``T``, checkpoint path) are example values; adjust them
    for your environment.

    :param log_dir: directory for all intermediate files. When ``None``, a new
        directory ``./log-<model_name><unix time>`` is created.
    :return: None
    """
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)
    # Per-channel mean / std used to z-score-normalize the ANN inputs.
    z_norm_mean = (0.485, 0.456, 0.406)
    z_norm_std = (0.229, 0.224, 0.225)

    # example setting
    device = 'cuda:0'
    dataset_dir = 'Dataset/ILSVRC2012'
    batch_size = 64
    T = 2000
    model_name = 'imagenetresnet50'
    checkpoint_path = './model_lib/imagenet/checkpoint/ResNet50-state-dict.pth'

    if log_dir is None:
        log_dir = './log-' + model_name + str(time.time())
    os.makedirs(log_dir, exist_ok=True)
    print("All the temp files are saved to ", log_dir)

    # The ANN is evaluated on z-score-normalized images.
    ann_transform_test = transforms.Compose([
        transforms.Resize(int(224 / 0.875)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(z_norm_mean, z_norm_std),
    ])
    # The SNN path receives un-normalized tensors; the mean / std are handed
    # to utils.onnx_ann2snn through its ``z_score`` argument instead.
    snn_transform = transforms.Compose([
        transforms.Resize(int(224 / 0.875)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    snn_train_data_dataset = torchvision.datasets.ImageFolder(
        root=os.path.join(dataset_dir, 'train'), transform=snn_transform)
    ann_test_data_dataset = torchvision.datasets.ImageFolder(
        root=os.path.join(dataset_dir, 'val'), transform=ann_transform_test)
    snn_test_data_dataset = torchvision.datasets.ImageFolder(
        root=os.path.join(dataset_dir, 'val'), transform=snn_transform)
    ann_test_data_loader = torch.utils.data.DataLoader(
        dataset=ann_test_data_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        pin_memory=True)
    snn_test_data_loader = torch.utils.data.DataLoader(
        dataset=snn_test_data_dataset,
        batch_size=16,
        shuffle=False,
        drop_last=False,
        pin_memory=True)

    # Load, print and store the default ann2snn configuration.
    config = utils.Config.default_config
    print('ann2snn config:\n\t', config)
    utils.Config.store_config(os.path.join(log_dir, 'default_config.json'), config)

    loss_function = nn.CrossEntropyLoss()
    ann = resnet.resnet50().to(device)
    # map_location makes loading independent of the device the checkpoint
    # was saved from; load_state_dict copies the values into ``ann``.
    checkpoint_state_dict = torch.load(checkpoint_path, map_location=device)
    ann.load_state_dict(checkpoint_state_dict)
    print('Directly load model', checkpoint_path)

    # Load the data used to normalize the model: the first 1/500 of the
    # training samples (at least one). Indexing the dataset goes through
    # ImageFolder's loader, which closes each file and converts it to RGB, so
    # grayscale / CMYK JPEGs still produce (3, 224, 224) tensors that stack.
    norm_set_len = max(1, len(snn_train_data_dataset) // 500)
    print('Using %d pictures as norm set' % (norm_set_len))
    norm_tensor = torch.stack(
        [snn_train_data_dataset[idx][0] for idx in range(norm_set_len)])

    # Validate the original ANN before conversion.
    utils.val_ann(net=ann,
                  device=device,
                  data_loader=ann_test_data_loader,
                  loss_function=loss_function)

    utils.onnx_ann2snn(model_name=model_name,
                       ann=ann,
                       norm_tensor=norm_tensor,
                       loss_function=loss_function,
                       test_data_loader=snn_test_data_loader,
                       device=device,
                       T=T,
                       log_dir=log_dir,
                       config=config,
                       z_score=(z_norm_mean, z_norm_std))
def main(log_dir=None):
    """Convert a pretrained CIFAR-10 ResNet-18 to an SNN and simulate it.

    Settings (devices, dataset root, batch size, simulation length, model
    name) are read interactively from stdin. The checkpoint
    ``./SJ-cifar10-resnet18_model-sample.pth`` is loaded, parsed with the ONNX
    kernel of ``parser`` and simulated with ``classify_simulator``.

    :param log_dir: directory for all outputs. When ``None``, a directory
        ``<model_name>-<timestamp>`` is created.
    :return: None
    """
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)
    train_device = input(
        '输入运行的设备,例如“cpu”或“cuda:0”\n input training device, e.g., "cpu" or "cuda:0": '
    )
    parser_device = input(
        '输入分析模型的设备,例如“cpu”或“cuda:0”\n input parsing device, e.g., "cpu" or "cuda:0": '
    )
    # The simulator currently runs on the same device as the parser.
    simulator_device = parser_device
    dataset_dir = input(
        '输入保存cifar10数据集的位置,例如“./”\n input root directory for saving cifar10 dataset, e.g., "./": '
    )
    batch_size = int(
        input('输入batch_size,例如“128”\n input batch_size, e.g., "128": '))
    T = int(input('输入仿真时长,例如“400”\n input simulating steps, e.g., "400": '))
    model_name = input(
        '输入模型名字,例如“resnet18_cifar10”\n input model name, for log_dir generating , e.g., "resnet18_cifar10": '
    )
    # Per-channel mean / std handed to the parser as ``z_norm``.
    z_norm_mean = (0.4914, 0.4822, 0.4465)
    z_norm_std = (0.2023, 0.1994, 0.2010)

    if log_dir is None:
        from datetime import datetime
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        log_dir = model_name + '-' + current_time
    os.makedirs(log_dir, exist_ok=True)

    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()])
    train_data_dataset = torchvision.datasets.CIFAR10(root=dataset_dir,
                                                      train=True,
                                                      transform=transform,
                                                      download=True)
    train_data_loader = torch.utils.data.DataLoader(dataset=train_data_dataset,
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    drop_last=False)
    test_data_dataset = torchvision.datasets.CIFAR10(root=dataset_dir,
                                                     train=False,
                                                     transform=transform,
                                                     download=True)
    test_data_loader = torch.utils.data.DataLoader(dataset=test_data_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   drop_last=False)

    ann = resnet.ResNet18().to(train_device)
    loss_function = nn.CrossEntropyLoss()
    checkpoint_state_dict = torch.load(
        './SJ-cifar10-resnet18_model-sample.pth', map_location=train_device)
    ann.load_state_dict(checkpoint_state_dict)

    # Load the data used to normalize the model: about 0.4% of the training
    # batches. Clamp to at least one batch -- otherwise, for large batch
    # sizes, the break index would be -1 and the whole training set would be
    # concatenated.
    percentage = 0.004
    num_norm_batches = max(1, int(len(train_data_loader) * percentage))
    norm_data_list = []
    for idx, (imgs, _) in enumerate(train_data_loader):
        norm_data_list.append(imgs)
        if idx + 1 >= num_norm_batches:
            break
    norm_data = torch.cat(norm_data_list)
    print('use %d imgs to parse' % (norm_data.size(0)))

    onnxparser = parser(name=model_name,
                        log_dir=log_dir + '/parser',
                        kernel='onnx',
                        z_norm=(z_norm_mean, z_norm_std))
    snn = onnxparser.parse(ann, norm_data.to(parser_device))
    # Validate the ANN re-loaded from the file written by the parser.
    ann_acc = utils.val_ann(
        torch.load(onnxparser.ann_filename).to(train_device), train_device,
        test_data_loader, loss_function)
    torch.save(snn, os.path.join(log_dir, 'snn-' + model_name + '.pkl'))

    fig = plt.figure('simulator')
    sim = classify_simulator(snn,
                             log_dir=log_dir + '/simulator',
                             device=simulator_device,
                             canvas=fig)
    sim.simulate(test_data_loader,
                 T=T,
                 online_drawer=True,
                 ann_acc=ann_acc,
                 fig_name=model_name,
                 step_max=True)
def main(log_dir=None):
    """Train a CNN on FashionMNIST, convert it to an SNN and simulate it.

    Settings are read interactively from stdin. If ``log_dir`` already
    contains ``<model_name>.pkl``, training is skipped and that model is
    loaded. Otherwise the ANN is trained and the best-validating epoch is
    saved. The model is then parsed with the ONNX kernel and simulated with
    ``classify_simulator``.

    :param log_dir: directory for all outputs. When ``None``, a directory
        ``<model_name>-<timestamp>`` is created and training always runs.
    :return: None
    """
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)
    train_device = input(
        '输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
    parser_device = input(
        '输入分析模型的设备,例如“cpu”或“cuda:0”\n input parsing device, e.g., "cpu" or "cuda:0": '
    )
    # The simulator currently runs on the same device as the parser.
    simulator_device = parser_device
    dataset_dir = input(
        '输入保存FashionMNIST数据集的位置,例如“./”\n input root directory for saving FashionMNIST dataset, e.g., "./": '
    )
    batch_size = int(
        input('输入batch_size,例如“128”\n input batch_size, e.g., "128": '))
    learning_rate = float(
        input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
    T = int(input('输入仿真时长,例如“400”\n input simulating steps, e.g., "400": '))
    train_epoch = int(
        input(
            '输入训练轮数,即遍历训练集的次数,例如“100”\n input training epochs, e.g., "100": '))
    model_name = input(
        '输入模型名字,例如“cnn_fashionmnist”\n input model name, for log_dir generating , e.g., "cnn_fashionmnist": '
    )

    if log_dir is None:
        from datetime import datetime
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        log_dir = model_name + '-' + current_time
        os.makedirs(log_dir, exist_ok=True)
        load = False
    else:
        os.makedirs(log_dir, exist_ok=True)
        # Reuse a previously trained model if one exists in log_dir.
        load = os.path.exists(os.path.join(log_dir, model_name + '.pkl'))
        if not load:
            print('%s has no model to load.' % (log_dir))

    # Initialize the data loaders.
    train_data_dataset = torchvision.datasets.FashionMNIST(
        root=dataset_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    train_data_loader = torch.utils.data.DataLoader(train_data_dataset,
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.FashionMNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=100,
        shuffle=True,
        drop_last=False)

    ann = ANN().to(train_device)
    loss_function = nn.CrossEntropyLoss()
    if not load:
        writer = SummaryWriter(log_dir)
        optimizer = torch.optim.Adam(ann.parameters(),
                                     lr=learning_rate,
                                     weight_decay=5e-4)
        best_acc = 0.0
        for epoch in range(train_epoch):
            # Train / validate with the ready-made helpers in ``utils``; the
            # training loop is the same as for a classic ANN.
            utils.train_ann(net=ann,
                            device=train_device,
                            data_loader=train_data_loader,
                            optimizer=optimizer,
                            loss_function=loss_function,
                            epoch=epoch)
            acc = utils.val_ann(net=ann,
                                device=train_device,
                                data_loader=test_data_loader,
                                loss_function=loss_function,
                                epoch=epoch)
            # Track the best accuracy so only improving epochs overwrite the
            # saved model.
            if best_acc <= acc:
                best_acc = acc
                utils.save_model(ann, log_dir, model_name + '.pkl')
            writer.add_scalar('val_accuracy', acc, epoch)
        writer.close()

    ann = torch.load(os.path.join(log_dir, model_name + '.pkl'))
    print('validating best model...')
    ann_acc = utils.val_ann(net=ann,
                            device=train_device,
                            data_loader=test_data_loader,
                            loss_function=loss_function)

    # Load the data used to normalize the model: about 0.4% of the training
    # batches, clamped to at least one batch so the loop always stops early.
    percentage = 0.004
    num_norm_batches = max(1, int(len(train_data_loader) * percentage))
    norm_data_list = []
    for idx, (imgs, _) in enumerate(train_data_loader):
        norm_data_list.append(imgs)
        if idx + 1 >= num_norm_batches:
            break
    norm_data = torch.cat(norm_data_list)
    print('use %d imgs to parse' % (norm_data.size(0)))

    onnxparser = parser(name=model_name,
                        log_dir=log_dir + '/parser',
                        kernel='onnx')
    snn = onnxparser.parse(ann, norm_data.to(parser_device))
    torch.save(snn, os.path.join(log_dir, 'snn-' + model_name + '.pkl'))

    fig = plt.figure('simulator')
    sim = classify_simulator(snn,
                             log_dir=log_dir + '/simulator',
                             device=simulator_device,
                             canvas=fig)
    sim.simulate(test_data_loader,
                 T=T,
                 online_drawer=True,
                 ann_acc=ann_acc,
                 fig_name=model_name,
                 step_max=True)
def main(log_dir=None):
    r'''
    :return: None

    Train a Conv-ReLU-[Conv-ReLU]-FC-ReLU network on MNIST, convert it to an
    SNN with the ONNX parser kernel and simulate the SNN. Example run (every
    prompt is printed in Chinese followed by English; only the English part
    is shown, and the output is abridged):

    .. code-block:: python

        >>> import spikingjelly.clock_driven.ann2snn.examples.cnn_mnist as cnn_mnist
        >>> cnn_mnist.main()
        input device, e.g., "cpu" or "cuda:0": cuda:15
        input root directory for saving MNIST dataset, e.g., "./": ./mnist
        input batch_size, e.g., "64": 128
        input learning rate, e.g., "1e-3": 1e-3
        input simulating steps, e.g., "100": 100
        input training epochs, e.g., "10": 10
        input model name, for log_dir generating , e.g., "cnn_mnist"
        Epoch 0 [1/937] ANN Training Loss:2.252 Accuracy:0.078
        ...
        Epoch 0 [901/937] ANN Training Loss:0.641 Accuracy:0.889
        Epoch 0 [100/100] ANN Validating Loss:0.327 Accuracy:0.881
        Save model to: cnn_mnist-XXXXX\cnn_mnist.pkl
        ......
        --------------------simulator summary--------------------
        time elapsed: 46.55072790000008 (sec)
        ---------------------------------------------------------

    If ``log_dir`` is given and already contains ``<model_name>.pkl``,
    training is skipped and that model is used. When ``log_dir`` is ``None``,
    a directory ``<model_name>-<timestamp>`` is created.
    '''
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)
    train_device = input('输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
    parser_device = input('输入分析模型的设备,例如“cpu”或“cuda:0”\n input parsing device, e.g., "cpu" or "cuda:0": ')
    # The simulator currently runs on the same device as the parser.
    simulator_device = parser_device
    dataset_dir = input('输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": ')
    batch_size = int(input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
    learning_rate = float(input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
    T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
    train_epoch = int(input('输入训练轮数,即遍历训练集的次数,例如“10”\n input training epochs, e.g., "10": '))
    model_name = input('输入模型名字,例如“cnn_mnist”\n input model name, for log_dir generating , e.g., "cnn_mnist": ')

    if log_dir is None:
        from datetime import datetime
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        log_dir = model_name + '-' + current_time
        os.makedirs(log_dir, exist_ok=True)
        load = False
    else:
        os.makedirs(log_dir, exist_ok=True)
        # Reuse a previously trained model if one exists in log_dir.
        load = os.path.exists(os.path.join(log_dir, model_name + '.pkl'))
        if not load:
            # log_dir is a str: '%d' here used to raise TypeError.
            print('%s has no model to load.' % (log_dir))

    # Initialize the data loaders.
    train_data_dataset = torchvision.datasets.MNIST(
        root=dataset_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    train_data_loader = torch.utils.data.DataLoader(
        train_data_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=100,
        shuffle=True,
        drop_last=False)

    ann = ANN().to(train_device)
    loss_function = nn.CrossEntropyLoss()
    if not load:
        writer = SummaryWriter(log_dir)
        optimizer = torch.optim.Adam(ann.parameters(), lr=learning_rate, weight_decay=5e-4)
        best_acc = 0.0
        for epoch in range(train_epoch):
            # Train / validate with the ready-made helpers in ``utils``; the
            # training loop is the same as for a classic ANN.
            utils.train_ann(net=ann,
                            device=train_device,
                            data_loader=train_data_loader,
                            optimizer=optimizer,
                            loss_function=loss_function,
                            epoch=epoch)
            acc = utils.val_ann(net=ann,
                                device=train_device,
                                data_loader=test_data_loader,
                                loss_function=loss_function,
                                epoch=epoch)
            # Track the best accuracy so only improving epochs overwrite the
            # saved model.
            if best_acc <= acc:
                best_acc = acc
                utils.save_model(ann, log_dir, model_name + '.pkl')
            writer.add_scalar('val_accuracy', acc, epoch)
        writer.close()

    ann = torch.load(os.path.join(log_dir, model_name + '.pkl'))
    print('validating best model...')
    ann_acc = utils.val_ann(net=ann,
                            device=train_device,
                            data_loader=test_data_loader,
                            loss_function=loss_function)

    # Load the data used to normalize the model: about 0.4% of the training
    # batches, clamped to at least one batch so the loop always stops early.
    percentage = 0.004
    num_norm_batches = max(1, int(len(train_data_loader) * percentage))
    norm_data_list = []
    for idx, (imgs, _) in enumerate(train_data_loader):
        norm_data_list.append(imgs)
        if idx + 1 >= num_norm_batches:
            break
    norm_data = torch.cat(norm_data_list)
    print('use %d imgs to parse' % (norm_data.size(0)))

    # Call the parser with the ONNX kernel.
    onnxparser = parser(name=model_name,
                        log_dir=log_dir + '/parser',
                        kernel='onnx')
    snn = onnxparser.parse(ann, norm_data.to(parser_device))

    # Save the converted SNN model.
    torch.save(snn, os.path.join(log_dir, 'snn-' + model_name + '.pkl'))
    fig = plt.figure('simulator')
    # Simulator for the classification task.
    sim = classify_simulator(snn,
                             log_dir=log_dir + '/simulator',
                             device=simulator_device,
                             canvas=fig)
    # Simulate the SNN.
    sim.simulate(test_data_loader,
                 T=T,
                 online_drawer=True,
                 ann_acc=ann_acc,
                 fig_name=model_name,
                 step_max=True)
def main(log_dir=None):
    r'''
    :return: None

    Train a Conv-ReLU-[Conv-ReLU]-FC-ReLU network on MNIST and run the
    standard ANN2SNN conversion and simulation via ``utils.pytorch_ann2snn``.
    Example run (every prompt is printed in Chinese followed by English; only
    the English part is shown, and the output is abridged):

    .. code-block:: python

        >>> import spikingjelly.clock_driven.ann2snn.examples.if_cnn_mnist as if_cnn_mnist
        >>> if_cnn_mnist.main()
        input device, e.g., "cpu" or "cuda:0": cuda:15
        input root directory for saving MNIST dataset, e.g., "./": ./mnist
        input batch_size, e.g., "64": 128
        input learning rate, e.g., "1e-3": 1e-3
        input simulating steps, e.g., "100": 100
        input training epochs, e.g., "10": 10
        input model name, for log_dir generating , e.g., "mnist"
        All the temp files are saved to  ./log-mnist1596804385.476601
        Epoch 0 [1/937] ANN Training Loss:2.252 Accuracy:0.078
        ...
        Epoch 9 [100/100] ANN Validating Loss:0.042 Accuracy:0.988
        Save model to: ./log-mnist1596804957.0179427\mnist.pkl

    If ``log_dir`` is ``None`` (or not a folder containing a saved model),
    a log folder ``./log-<model_name><unix time>`` is used and the ANN is
    trained from scratch. If ``log_dir`` already contains
    ``<model_name>.pkl``, training is skipped.
    '''
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)
    device = input(
        '输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
    dataset_dir = input(
        '输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": '
    )
    batch_size = int(
        input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
    learning_rate = float(
        input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
    T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
    train_epoch = int(
        input('输入训练轮数,即遍历训练集的次数,例如“10”\n input training epochs, e.g., "10": '))
    model_name = input(
        '输入模型名字,例如“mnist”\n input model name, for log_dir generating , e.g., "mnist": '
    )

    if log_dir is None:
        log_dir = './log-' + model_name + str(time.time())
        os.makedirs(log_dir, exist_ok=True)
        load = False
    else:
        os.makedirs(log_dir, exist_ok=True)
        # Reuse a previously trained model if one exists in log_dir.
        load = os.path.exists(os.path.join(log_dir, model_name + '.pkl'))
        if not load:
            print('Such log_dir has no model to load.')
    print("All the temp files are saved to ", log_dir)

    # Initialize the data loaders.
    train_data_dataset = torchvision.datasets.MNIST(
        root=dataset_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    train_data_loader = torch.utils.data.DataLoader(train_data_dataset,
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=100,
        shuffle=True,
        drop_last=False)

    # Load the default configuration, print it and store it in log_dir.
    config = utils.Config.default_config
    print('ann2snn config:\n\t', config)
    utils.Config.store_config(os.path.join(log_dir, 'default_config.json'), config)

    ann = ANN().to(device)
    loss_function = nn.CrossEntropyLoss()
    if not load:
        writer = SummaryWriter(log_dir)
        optimizer = torch.optim.Adam(ann.parameters(),
                                     lr=learning_rate,
                                     weight_decay=5e-4)
        best_acc = 0.0
        for epoch in range(train_epoch):
            # Train / validate with the ready-made helpers in ``utils``; the
            # training loop is the same as for a classic ANN.
            utils.train_ann(net=ann,
                            device=device,
                            data_loader=train_data_loader,
                            optimizer=optimizer,
                            loss_function=loss_function,
                            epoch=epoch)
            acc = utils.val_ann(net=ann,
                                device=device,
                                data_loader=test_data_loader,
                                loss_function=loss_function,
                                epoch=epoch)
            # Track the best accuracy so only improving epochs overwrite the
            # saved model.
            if best_acc <= acc:
                best_acc = acc
                utils.save_model(ann, log_dir, model_name + '.pkl')
            writer.add_scalar('val_accuracy', acc, epoch)
        writer.close()
    else:
        print('Directly load model', model_name + '.pkl')

    # Load the data used to normalize the model: the first 1/500 of the raw
    # training images, scaled from uint8 to [0, 1] and given a channel axis.
    norm_set_len = int(train_data_dataset.data.shape[0] / 500)
    print('Using %d pictures as norm set' % (norm_set_len))
    norm_set = train_data_dataset.data[:norm_set_len, :, :].float() / 255
    norm_tensor = norm_set.view(-1, 1, 28, 28)

    # Standard ANN2SNN conversion: converts the model and simulates the SNN.
    # NOTE(review): the model itself is not passed; presumably
    # pytorch_ann2snn loads <model_name>.pkl from log_dir -- confirm.
    utils.pytorch_ann2snn(model_name=model_name,
                          norm_data=norm_tensor,
                          test_data_loader=test_data_loader,
                          device=device,
                          T=T,
                          log_dir=log_dir,
                          config=config)