def main(log_dir=None):
    """Convert a pretrained ResNet-50 ANN to an SNN on ImageNet (ILSVRC2012).

    Loads a ResNet-50 checkpoint, builds the ImageNet train/val loaders,
    collects a small "norm set" (1/500 of the training images) used to
    normalize the model, validates the ANN, and finally runs the ONNX-kernel
    ANN-to-SNN conversion plus SNN simulation via ``utils.onnx_ann2snn``.

    :param log_dir: directory for temporary files and logs; a timestamped
        directory derived from the model name is created when ``None``.
    :return: None
    """
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Standard ImageNet channel statistics for z-score normalization.
    z_norm_mean = (0.485, 0.456, 0.406)
    z_norm_std = (0.229, 0.224, 0.225)

    # example setting
    device = 'cuda:0'
    dataset_dir = 'Dataset/ILSVRC2012'
    batch_size = 64
    learning_rate = 1e-4
    T = 2000  # number of SNN simulation time steps
    train_epoch = 40
    model_name = 'imagenetresnet50'

    load = False
    if log_dir is None:
        log_dir = './log-' + model_name + str(time.time())
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    print("All the temp files are saved to ", log_dir)

    # Augmented pipeline for ANN training; normalized like the checkpoint.
    ann_transform_train = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(z_norm_mean, z_norm_std),
    ])

    # Deterministic resize/center-crop pipeline for ANN validation.
    ann_transform_test = transforms.Compose([
        transforms.Resize(int(224 / 0.875)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(z_norm_mean, z_norm_std),
    ])

    # SNN pipeline: no Normalize here — z-score is handled inside the
    # conversion via the ``z_score`` argument of ``utils.onnx_ann2snn``.
    snn_transform = transforms.Compose([
        transforms.Resize(int(224 / 0.875)),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ])

    ann_train_data_dataset = torchvision.datasets.ImageFolder(
        root=os.path.join(dataset_dir, 'train'), transform=ann_transform_train)
    snn_train_data_dataset = torchvision.datasets.ImageFolder(
        root=os.path.join(dataset_dir, 'train'), transform=snn_transform)
    ann_train_data_loader = torch.utils.data.DataLoader(
        dataset=ann_train_data_dataset,
        batch_size=batch_size,
        shuffle=True,
        #num_workers=4,
        drop_last=True,
        pin_memory=True)

    ann_test_data_dataset = torchvision.datasets.ImageFolder(
        root=os.path.join(dataset_dir, 'val'), transform=ann_transform_test)
    snn_test_data_dataset = torchvision.datasets.ImageFolder(
        root=os.path.join(dataset_dir, 'val'), transform=snn_transform)
    ann_test_data_loader = torch.utils.data.DataLoader(
        dataset=ann_test_data_dataset,
        batch_size=batch_size,
        shuffle=False,
        #num_workers=4,
        drop_last=False,
        pin_memory=True)
    # Smaller batch for SNN simulation: each sample is simulated for T steps.
    snn_test_data_loader = torch.utils.data.DataLoader(
        dataset=snn_test_data_dataset,
        batch_size=16,
        shuffle=False,
        #num_workers=4,
        drop_last=False,
        pin_memory=True)

    # Load the default ann2snn configuration, save and print it.
    config = utils.Config.default_config
    print('ann2snn config:\n\t', config)
    utils.Config.store_config(os.path.join(log_dir, 'default_config.json'),
                              config)

    loss_function = nn.CrossEntropyLoss()

    ann = resnet.resnet50().to(device)
    checkpoint_state_dict = torch.load(
        './model_lib/imagenet/checkpoint/ResNet50-state-dict.pth')
    ann.load_state_dict(checkpoint_state_dict)

    print('Directly load model', model_name + '.pth')

    # Load the data used to normalize the model (1/500 of the training set).
    norm_set_len = int(len(snn_train_data_dataset.samples) / 500)
    print('Using %d pictures as norm set' % (norm_set_len))
    norm_set_list = []
    for idx, (datapath, target) in enumerate(snn_train_data_dataset.samples):
        # ImageNet contains grayscale/CMYK images; force 3-channel RGB so
        # torch.stack below does not fail on mismatched tensor shapes.
        norm_set_list.append(snn_transform(Image.open(datapath).convert('RGB')))
        if idx == norm_set_len - 1:
            break
    norm_tensor = torch.stack(norm_set_list)

    # Baseline ANN accuracy before conversion.
    ann_acc = utils.val_ann(net=ann,
                            device=device,
                            data_loader=ann_test_data_loader,
                            loss_function=loss_function)

    utils.onnx_ann2snn(model_name=model_name,
                       ann=ann,
                       norm_tensor=norm_tensor,
                       loss_function=loss_function,
                       test_data_loader=snn_test_data_loader,
                       device=device,
                       T=T,
                       log_dir=log_dir,
                       config=config,
                       z_score=(z_norm_mean, z_norm_std))
Exemplo n.º 2
0
def main(log_dir=None):
    """Convert a pretrained ResNet-18 ANN to an SNN on CIFAR-10 and simulate it.

    Interactively asks for devices, dataset location, batch size, simulation
    length and model name; loads a ResNet-18 checkpoint, parses it to an SNN
    with the ONNX kernel, and runs the classification simulator for T steps.

    :param log_dir: directory for logs; auto-generated from the model name
        and current time when ``None``.
    :return: None
    """
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Interactive configuration (prompts are bilingual Chinese/English).
    train_device = input(
        '输入运行的设备,例如“cpu”或“cuda:0”\n input training device, e.g., "cpu" or "cuda:0": '
    )
    parser_device = input(
        '输入分析模型的设备,例如“cpu”或“cuda:0”\n input parsing device, e.g., "cpu" or "cuda:0": '
    )
    simulator_device = parser_device
    # simulator_device = input('输入SNN仿真的设备(支持多线程),例如“cpu,cuda:0”或“cuda:0,cuda:1”\n input SNN simulating device (support multithread), e.g., "cpu,cuda:0" or "cuda:0,cuda:1": ').split(',')
    dataset_dir = input(
        '输入保存cifar10数据集的位置,例如“./”\n input root directory for saving cifar10 dataset, e.g., "./": '
    )
    batch_size = int(
        input('输入batch_size,例如“128”\n input batch_size, e.g., "128": '))
    T = int(input('输入仿真时长,例如“400”\n input simulating steps, e.g., "400": '))
    model_name = input(
        '输入模型名字,例如“resnet18_cifar10”\n input model name, for log_dir generating , e.g., "resnet18_cifar10": '
    )

    # CIFAR-10 channel statistics, forwarded to the parser as z_norm.
    z_norm_mean = (0.4914, 0.4822, 0.4465)
    z_norm_std = (0.2023, 0.1994, 0.2010)

    load = False
    if log_dir == None:
        from datetime import datetime
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        log_dir = model_name + '-' + current_time
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
    else:
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

    if not load:
        writer = SummaryWriter(log_dir)

    # ToTensor only: normalization is handled by the parser via z_norm below.
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()])

    train_data_dataset = torchvision.datasets.CIFAR10(root=dataset_dir,
                                                      train=True,
                                                      transform=transform,
                                                      download=True)
    train_data_loader = torch.utils.data.DataLoader(dataset=train_data_dataset,
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    drop_last=False)
    test_data_dataset = torchvision.datasets.CIFAR10(root=dataset_dir,
                                                     train=False,
                                                     transform=transform,
                                                     download=True)
    # NOTE(review): shuffle=True on the test loader is unusual — confirm intended.
    test_data_loader = torch.utils.data.DataLoader(dataset=test_data_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   drop_last=False)

    ann = resnet.ResNet18().to(train_device)
    loss_function = nn.CrossEntropyLoss()
    checkpoint_state_dict = torch.load(
        './SJ-cifar10-resnet18_model-sample.pth')
    ann.load_state_dict(checkpoint_state_dict)

    # Load a small fraction of the training data to normalize the model.
    percentage = 0.004  # load 0.004 of the data
    norm_data_list = []
    for idx, (imgs, targets) in enumerate(train_data_loader):
        norm_data_list.append(imgs)
        if idx == int(len(train_data_loader) * percentage) - 1:
            break
    norm_data = torch.cat(norm_data_list)
    print('use %d imgs to parse' % (norm_data.size(0)))

    # Parse the ANN into an SNN using the ONNX kernel.
    onnxparser = parser(name=model_name,
                        log_dir=log_dir + '/parser',
                        kernel='onnx',
                        z_norm=(z_norm_mean, z_norm_std))

    snn = onnxparser.parse(ann, norm_data.to(parser_device))
    # Baseline ANN accuracy (re-loaded from the parser's saved ANN file).
    ann_acc = utils.val_ann(
        torch.load(onnxparser.ann_filename).to(train_device), train_device,
        test_data_loader, loss_function)
    torch.save(snn, os.path.join(log_dir, 'snn-' + model_name + '.pkl'))
    fig = plt.figure('simulator')
    # Simulate the converted SNN and draw accuracy-vs-T online.
    sim = classify_simulator(snn,
                             log_dir=log_dir + '/simulator',
                             device=simulator_device,
                             canvas=fig)
    sim.simulate(test_data_loader,
                 T=T,
                 online_drawer=True,
                 ann_acc=ann_acc,
                 fig_name=model_name,
                 step_max=True)
Exemplo n.º 3
0
def main(log_dir=None):
    """Train a CNN on FashionMNIST, convert it to an SNN and simulate it.

    Interactively asks for devices and hyper-parameters. If ``log_dir``
    already contains a trained model it is loaded directly; otherwise the ANN
    is trained and the best checkpoint (by validation accuracy) is kept. The
    best model is then parsed to an SNN with the ONNX kernel and simulated.

    :param log_dir: directory for logs/checkpoints; auto-generated from the
        model name and current time when ``None``.
    :return: None
    """
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Interactive configuration (prompts are bilingual Chinese/English).
    train_device = input(
        '输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
    parser_device = input(
        '输入分析模型的设备,例如“cpu”或“cuda:0”\n input parsing device, e.g., "cpu" or "cuda:0": '
    )
    simulator_device = parser_device
    # simulator_device = input(
    #     '输入SNN仿真的设备(支持多线程),例如“cpu,cuda:0”或“cuda:0,cuda:1”\n input SNN simulating device (support multithread), e.g., "cpu,cuda:0" or "cuda:0,cuda:1": ').split(
    #     ',')
    dataset_dir = input(
        '输入保存MNIST数据集的位置,例如“./”\n input root directory for saving FashionMNIST dataset, e.g., "./": '
    )
    batch_size = int(
        input('输入batch_size,例如“128”\n input batch_size, e.g., "128": '))
    learning_rate = float(
        input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
    T = int(input('输入仿真时长,例如“400”\n input simulating steps, e.g., "400": '))
    train_epoch = int(
        input(
            '输入训练轮数,即遍历训练集的次数,例如“100”\n input training epochs, e.g., "100": '))
    model_name = input(
        '输入模型名字,例如“cnn_fashionmnist”\n input model name, for log_dir generating , e.g., "cnn_fashionmnist": '
    )

    load = False
    if log_dir is None:
        from datetime import datetime
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        log_dir = model_name + '-' + current_time
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
    else:
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        # Only skip training when a previously saved model exists.
        if not os.path.exists(os.path.join(log_dir, model_name + '.pkl')):
            print('%s has no model to load.' % (log_dir))
            load = False
        else:
            load = True

    if not load:
        writer = SummaryWriter(log_dir)

    # Initialize data loaders.
    train_data_dataset = torchvision.datasets.FashionMNIST(
        root=dataset_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    train_data_loader = torch.utils.data.DataLoader(train_data_dataset,
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.FashionMNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=100,
        shuffle=True,
        drop_last=False)

    ann = ANN().to(train_device)
    loss_function = nn.CrossEntropyLoss()
    if not load:
        optimizer = torch.optim.Adam(ann.parameters(),
                                     lr=learning_rate,
                                     weight_decay=5e-4)
        best_acc = 0.0
        for epoch in range(train_epoch):
            # Train the network using the pre-prepared routine in ``utils``;
            # this follows the same pattern as classic ANN training.
            utils.train_ann(net=ann,
                            device=train_device,
                            data_loader=train_data_loader,
                            optimizer=optimizer,
                            loss_function=loss_function,
                            epoch=epoch)
            # Validate the network using the pre-prepared routine in ``utils``.
            acc = utils.val_ann(net=ann,
                                device=train_device,
                                data_loader=test_data_loader,
                                loss_function=loss_function,
                                epoch=epoch)
            if best_acc <= acc:
                # Track the best accuracy so only improving epochs overwrite
                # the checkpoint (otherwise every epoch would be saved and the
                # "best model" loaded below would just be the last epoch).
                best_acc = acc
                utils.save_model(ann, log_dir, model_name + '.pkl')
            writer.add_scalar('val_accuracy', acc, epoch)
    ann = torch.load(os.path.join(log_dir, model_name + '.pkl'))
    print('validating best model...')
    ann_acc = utils.val_ann(net=ann,
                            device=train_device,
                            data_loader=test_data_loader,
                            loss_function=loss_function)

    # Load a small fraction of the training data to normalize the model.
    percentage = 0.004  # load 0.004 of the data
    norm_data_list = []
    for idx, (imgs, targets) in enumerate(train_data_loader):
        norm_data_list.append(imgs)
        if idx == int(len(train_data_loader) * percentage) - 1:
            break
    norm_data = torch.cat(norm_data_list)
    print('use %d imgs to parse' % (norm_data.size(0)))

    # Parse the ANN into an SNN using the ONNX kernel.
    onnxparser = parser(name=model_name,
                        log_dir=log_dir + '/parser',
                        kernel='onnx')
    snn = onnxparser.parse(ann, norm_data.to(parser_device))

    torch.save(snn, os.path.join(log_dir, 'snn-' + model_name + '.pkl'))
    fig = plt.figure('simulator')
    # Simulate the converted SNN and draw accuracy-vs-T online.
    sim = classify_simulator(snn,
                             log_dir=log_dir + '/simulator',
                             device=simulator_device,
                             canvas=fig)
    sim.simulate(test_data_loader,
                 T=T,
                 online_drawer=True,
                 ann_acc=ann_acc,
                 fig_name=model_name,
                 step_max=True)
Exemplo n.º 4
0
def main(log_dir=None):
    '''
        :return: None

        Train a Conv-ReLU-[Conv-ReLU]-FC-ReLU network, convert it to an SNN,
        and run MNIST classification. Example run:

        .. code-block:: python

            >>> import spikingjelly.clock_driven.ann2snn.examples.cnn_mnist as cnn_mnist
            >>> cnn_mnist.main()
            输入运行的设备,例如“cpu”或“cuda:0”
             input device, e.g., "cpu" or "cuda:0": cuda:15
            输入保存MNIST数据集的位置,例如“./”
             input root directory for saving MNIST dataset, e.g., "./": ./mnist
            输入batch_size,例如“64”
             input batch_size, e.g., "64": 128
            输入学习率,例如“1e-3”
             input learning rate, e.g., "1e-3": 1e-3
            输入仿真时长,例如“100”
             input simulating steps, e.g., "100": 100
            输入训练轮数,即遍历训练集的次数,例如“10”
             input training epochs, e.g., "10": 10
            输入模型名字,用于自动生成日志文档,例如“cnn_mnist”
             input model name, for log_dir generating , e.g., "cnn_mnist"

            Epoch 0 [1/937] ANN Training Loss:2.252 Accuracy:0.078
            Epoch 0 [101/937] ANN Training Loss:1.423 Accuracy:0.669
            Epoch 0 [201/937] ANN Training Loss:1.117 Accuracy:0.773
            Epoch 0 [301/937] ANN Training Loss:0.953 Accuracy:0.795
            Epoch 0 [401/937] ANN Training Loss:0.865 Accuracy:0.788
            Epoch 0 [501/937] ANN Training Loss:0.807 Accuracy:0.792
            Epoch 0 [601/937] ANN Training Loss:0.764 Accuracy:0.795
            Epoch 0 [701/937] ANN Training Loss:0.726 Accuracy:0.835
            Epoch 0 [801/937] ANN Training Loss:0.681 Accuracy:0.880
            Epoch 0 [901/937] ANN Training Loss:0.641 Accuracy:0.889
            100%|██████████| 100/100 [00:00<00:00, 116.12it/s]
            Epoch 0 [100/100] ANN Validating Loss:0.327 Accuracy:0.881
            Save model to: cnn_mnist-XXXXX\cnn_mnist.pkl
            ......
            --------------------simulator summary--------------------
            time elapsed: 46.55072790000008 (sec)
            ---------------------------------------------------------
    '''
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Interactive configuration (prompts are bilingual Chinese/English).
    train_device = input('输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
    parser_device = input('输入分析模型的设备,例如“cpu”或“cuda:0”\n input parsing device, e.g., "cpu" or "cuda:0": ')
    simulator_device = parser_device
    # simulator_device = input(
    #     '输入SNN仿真的设备(支持多线程),例如“cpu,cuda:0”或“cuda:0,cuda:1”\n input SNN simulating device (support multithread), e.g., "cpu,cuda:0" or "cuda:0,cuda:1": ').split(
    #     ',')
    dataset_dir = input('输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": ')
    batch_size = int(input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
    learning_rate = float(input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
    T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
    train_epoch = int(input('输入训练轮数,即遍历训练集的次数,例如“10”\n input training epochs, e.g., "10": '))
    model_name = input('输入模型名字,例如“cnn_mnist”\n input model name, for log_dir generating , e.g., "cnn_mnist": ')

    load = False
    if log_dir is None:
        from datetime import datetime
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        log_dir = model_name + '-' + current_time
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
    else:
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        # Only skip training when a previously saved model exists.
        if not os.path.exists(os.path.join(log_dir, model_name + '.pkl')):
            # Fix: log_dir is a string — the original '%d' placeholder raised
            # TypeError here.
            print('%s has no model to load.' % (log_dir))
            load = False
        else:
            load = True

    if not load:
        writer = SummaryWriter(log_dir)

    # Initialize data loaders.
    train_data_dataset = torchvision.datasets.MNIST(
        root=dataset_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    train_data_loader = torch.utils.data.DataLoader(
        train_data_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=100,
        shuffle=True,
        drop_last=False)

    ann = ANN().to(train_device)
    loss_function = nn.CrossEntropyLoss()
    if not load:
        optimizer = torch.optim.Adam(ann.parameters(), lr=learning_rate, weight_decay=5e-4)
        best_acc = 0.0
        for epoch in range(train_epoch):
            # Train the network using the pre-prepared routine in ``utils``;
            # this follows the same pattern as classic ANN training.
            utils.train_ann(net=ann,
                            device=train_device,
                            data_loader=train_data_loader,
                            optimizer=optimizer,
                            loss_function=loss_function,
                            epoch=epoch
                            )
            # Validate the network using the pre-prepared routine in ``utils``.
            acc = utils.val_ann(net=ann,
                                device=train_device,
                                data_loader=test_data_loader,
                                loss_function=loss_function,
                                epoch=epoch
                                )
            if best_acc <= acc:
                # Track the best accuracy so only improving epochs overwrite
                # the checkpoint (otherwise every epoch would be saved and the
                # "best model" loaded below would just be the last epoch).
                best_acc = acc
                utils.save_model(ann, log_dir, model_name + '.pkl')
            writer.add_scalar('val_accuracy', acc, epoch)
    ann = torch.load(os.path.join(log_dir, model_name + '.pkl'))
    print('validating best model...')
    ann_acc = utils.val_ann(net=ann,
                                device=train_device,
                                data_loader=test_data_loader,
                                loss_function=loss_function
                                )

    # Load a small fraction of the training data to normalize the model.
    percentage = 0.004 # load 0.004 of the data
    norm_data_list = []
    for idx, (imgs, targets) in enumerate(train_data_loader):
        norm_data_list.append(imgs)
        if idx == int(len(train_data_loader) * percentage) - 1:
            break
    norm_data = torch.cat(norm_data_list)
    print('use %d imgs to parse' % (norm_data.size(0)))

    # Call parser, use onnx kernel.
    onnxparser = parser(name=model_name,
                        log_dir=log_dir + '/parser',
                        kernel='onnx')
    snn = onnxparser.parse(ann, norm_data.to(parser_device))

    # Save the converted SNN model.
    torch.save(snn, os.path.join(log_dir,'snn-'+model_name+'.pkl'))
    fig = plt.figure('simulator')

    # Define the simulator for the classification task.
    sim = classify_simulator(snn,
                             log_dir=log_dir + '/simulator',
                             device=simulator_device,
                             canvas=fig
                             )
    # Simulate the SNN.
    sim.simulate(test_data_loader,
                T=T,
                online_drawer=True,
                ann_acc=ann_acc,
                fig_name=model_name,
                step_max=True
                )
Exemplo n.º 5
0
def main(log_dir=None):
    '''
        :return: None

        Train a Conv-ReLU-[Conv-ReLU]-FC-ReLU network, convert it to an SNN,
        and run MNIST classification. Example run:

        .. code-block:: python

            >>> import spikingjelly.clock_driven.ann2snn.examples.if_cnn_mnist as if_cnn_mnist
            >>> if_cnn_mnist.main()
            输入运行的设备,例如“cpu”或“cuda:0”
             input device, e.g., "cpu" or "cuda:0": cuda:15
            输入保存MNIST数据集的位置,例如“./”
             input root directory for saving MNIST dataset, e.g., "./": ./mnist
            输入batch_size,例如“64”
             input batch_size, e.g., "64": 128
            输入学习率,例如“1e-3”
             input learning rate, e.g., "1e-3": 1e-3
            输入仿真时长,例如“100”
             input simulating steps, e.g., "100": 100
            输入训练轮数,即遍历训练集的次数,例如“10”
             input training epochs, e.g., "10": 10
            输入模型名字,用于自动生成日志文档,例如“mnist”
             input model name, for log_dir generating , e.g., "mnist"

            If the input of the main function is not a folder with valid
            files, a log file folder is automatically generated. The first
            line of output is the location where log files are saved, e.g.:
             Terminal outputs root directory for saving logs, e.g., "./": ./log-mnist1596804385.476601

            Epoch 0 [1/937] ANN Training Loss:2.252 Accuracy:0.078
            Epoch 0 [101/937] ANN Training Loss:1.424 Accuracy:0.669
            Epoch 0 [201/937] ANN Training Loss:1.117 Accuracy:0.773
            Epoch 0 [301/937] ANN Training Loss:0.953 Accuracy:0.795
            Epoch 0 [401/937] ANN Training Loss:0.865 Accuracy:0.788
            Epoch 0 [501/937] ANN Training Loss:0.807 Accuracy:0.792
            Epoch 0 [601/937] ANN Training Loss:0.764 Accuracy:0.795
            Epoch 0 [701/937] ANN Training Loss:0.726 Accuracy:0.834
            Epoch 0 [801/937] ANN Training Loss:0.681 Accuracy:0.880
            Epoch 0 [901/937] ANN Training Loss:0.641 Accuracy:0.888
            Epoch 0 [100/100] ANN Validating Loss:0.328 Accuracy:0.881
            Save model to: ./log-mnist1596804385.476601\mnist.pkl
            ...
            Epoch 9 [901/937] ANN Training Loss:0.036 Accuracy:0.990
            Epoch 9 [100/100] ANN Validating Loss:0.042 Accuracy:0.988
            Save model to: ./log-mnist1596804957.0179427\mnist.pkl
    '''
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Interactive configuration (prompts are bilingual Chinese/English).
    device = input(
        '输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
    dataset_dir = input(
        '输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": '
    )
    batch_size = int(
        input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
    learning_rate = float(
        input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
    T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
    train_epoch = int(
        input('输入训练轮数,即遍历训练集的次数,例如“10”\n input training epochs, e.g., "10": '))
    model_name = input(
        '输入模型名字,例如“mnist”\n input model name, for log_dir generating , e.g., "mnist": '
    )

    load = False
    if log_dir is None:
        log_dir = './log-' + model_name + str(time.time())
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        print("All the temp files are saved to ", log_dir)
    else:
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        # Only skip training when a previously saved model exists.
        if not os.path.exists(os.path.join(log_dir, model_name + '.pkl')):
            print('Such log_dir has no model to load.')
            load = False
        else:
            load = True
        print("All the temp files are saved to ", log_dir)

    writer = SummaryWriter(log_dir)

    # Initialize data loaders.
    train_data_dataset = torchvision.datasets.MNIST(
        root=dataset_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True)
    train_data_loader = torch.utils.data.DataLoader(train_data_dataset,
                                                    batch_size=batch_size,
                                                    shuffle=True,
                                                    drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=100,
        shuffle=True,
        drop_last=False)

    # Load the default ann2snn configuration, save and print it.
    config = utils.Config.default_config
    print('ann2snn config:\n\t', config)
    utils.Config.store_config(os.path.join(log_dir, 'default_config.json'),
                              config)

    ann = ANN().to(device)
    loss_function = nn.CrossEntropyLoss()
    if not load:
        optimizer = torch.optim.Adam(ann.parameters(),
                                     lr=learning_rate,
                                     weight_decay=5e-4)
        best_acc = 0.0
        for epoch in range(train_epoch):
            # Train the network using the pre-prepared routine in ``utils``;
            # this follows the same pattern as classic ANN training.
            utils.train_ann(net=ann,
                            device=device,
                            data_loader=train_data_loader,
                            optimizer=optimizer,
                            loss_function=loss_function,
                            epoch=epoch)
            # Validate the network using the pre-prepared routine in ``utils``.
            acc = utils.val_ann(net=ann,
                                device=device,
                                data_loader=test_data_loader,
                                loss_function=loss_function,
                                epoch=epoch)
            if best_acc <= acc:
                # Track the best accuracy so only improving epochs overwrite
                # the checkpoint (otherwise every epoch would be saved and the
                # stored "best model" would just be the last epoch).
                best_acc = acc
                utils.save_model(ann, log_dir, model_name + '.pkl')
            writer.add_scalar('val_accuracy', acc, epoch)
    else:
        print('Directly load model', model_name + '.pkl')

    # Load the data used to normalize the model (1/500 of the training set).
    # MNIST pixels are uint8 in [0, 255]; scale to [0, 1] floats.
    norm_set_len = int(train_data_dataset.data.shape[0] / 500)
    print('Using %d pictures as norm set' % (norm_set_len))
    norm_set = train_data_dataset.data[:norm_set_len, :, :].float() / 255
    norm_tensor = torch.FloatTensor(norm_set).view(-1, 1, 28, 28)

    # ANN2SNN standard conversion: directly calling the function transforms
    # the model and simulates the resulting SNN.
    utils.pytorch_ann2snn(model_name=model_name,
                          norm_data=norm_tensor,
                          test_data_loader=test_data_loader,
                          device=device,
                          T=T,
                          log_dir=log_dir,
                          config=config)