def main():
    """Transfer BIM attack between two CIFAR-10 spiking networks.

    Adversarial images are crafted against ``source_net`` with an iterative
    sign-gradient (BIM) attack. The Poisson encoder is not differentiable, so
    the loss gradient is taken w.r.t. every encoded spike frame and summed
    over the ``T`` time steps as a surrogate for the image gradient. Each
    adversarial image is then classified by ``source_net`` (white-box
    success) and by the separately loaded ``target_net`` (transfer success).

    ``attack_type == 'UT'`` attacks the first 250 test images untargeted; any
    other value attacks every other class of each test image until 270
    attacks have been made. Prints per-sample results, then the mean L-inf
    perturbation and the success counts/rates of both networks.
    """
    # Run configuration (hard-coded; the former interactive prompts were removed).
    device = 'cuda:3'
    dataset_dir = '../../dataset/'
    class_num = 10
    T = 50  # number of Poisson-encoded simulation time steps
    phase = 'BIM'

    torch.cuda.empty_cache()
    encoder = encoding.PoissonEncoder()

    if phase != 'BIM':
        return

    source_model_path = './models/cifar10_spike_v1.pth'
    target_model_path = './models/cifar10_spike_v2.pth'
    iter_num = 25  # BIM iterations per attacked sample
    eta = 0.03  # size of one sign-gradient step
    attack_type = 'UT'  # 'UT' = untargeted, anything else = targeted
    clip_eps = 0.35  # L-inf radius around the original (normalized) image

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(
            root=dataset_dir, train=False, transform=transform_test,
            download=True),
        batch_size=1, shuffle=False, drop_last=False)

    # Normalized all-ones / all-zeros images: the valid pixel range after
    # normalization. ToTensor keeps the float64 dtype of these ndarrays, so
    # cast to float32 to stop clipping from promoting the image to float64.
    p_max = transform_test(np.ones((32, 32, 3))).float().to(device)
    p_min = transform_test(np.zeros((32, 32, 3))).float().to(device)

    # map_location makes loading independent of the device the checkpoint
    # was saved from.
    source_net = Net().to(device)
    source_net.load_state_dict(
        torch.load(source_model_path, map_location=device))
    target_net = Net().to(device)
    target_net.load_state_dict(
        torch.load(target_model_path, map_location=device))
    target_net.eval()

    targeted = attack_type != 'UT'
    max_attacks = 270 if targeted else 250

    def spike_grad_sign(img, goal):
        """Return the sign of d(loss)/d(spikes) summed over all T frames.

        ``goal`` is the true label (untargeted) or the target label
        (targeted). Resets ``source_net`` and clears its parameter gradients
        afterwards.
        NOTE(review): the per-step ``unsqueeze(0)`` assumes ``Net`` returns a
        1-D score vector for a single image -- confirm against ``Net``.
        """
        spike_train = []
        out_spikes_counter = None
        for _ in range(T):
            spike = encoder(img).float()
            spike.requires_grad = True
            spike_train.append(spike)
            step_out = source_net(spike).unsqueeze(0)
            if out_spikes_counter is None:
                out_spikes_counter = step_out
            else:
                out_spikes_counter += step_out
        loss = F.cross_entropy(out_spikes_counter / T, goal)
        loss.backward()

        rate = torch.zeros_like(spike_train[0])
        for spike in spike_train:
            rate += spike.grad
        source_net.reset_()
        # zero_grad() also copes with parameters whose .grad is None.
        source_net.zero_grad()
        return torch.sign(rate)

    def count_spikes(net, img):
        """Sum ``net``'s outputs over T fresh Poisson encodings of ``img``.

        ``net`` is not reset here; callers reset it afterwards.
        """
        out_spikes_counter = net(encoder(img).float())
        for _ in range(T - 1):
            out_spikes_counter += net(encoder(img).float())
        return out_spikes_counter

    def attack_flag(pred, goal):
        """Count predictions meeting the attack goal (0.0 or 1.0 per image)."""
        reached = pred == goal if targeted else pred != goal
        return reached.float().sum().item()

    mean_p = 0.0
    test_sum = 0
    source_success_sum = 0
    target_success_sum = 0
    for X, y in test_data_loader:
        label = y.to(device)
        if targeted:
            # One attack towards every class other than the true one.
            goals = [(label + i) % class_num for i in range(1, class_num)]
        else:
            goals = [label]

        for goal in goals:
            img = X.to(device)
            img_ori = img.clone()
            test_sum += 1
            if not targeted:
                print('Img %d' % test_sum)

            source_net.train()
            for _ in range(iter_num):
                img_grad = spike_grad_sign(img, goal)
                # Untargeted: climb the true-label loss; targeted: descend
                # the target-label loss.
                step = -eta * img_grad if targeted else eta * img_grad
                # Feed the clipped adversarial image back into the next
                # iteration (previously it was discarded).
                img = clip_by_tensor(img + step, img_ori - clip_eps,
                                     img_ori + clip_eps, p_min, p_max)

            source_net.eval()
            with torch.no_grad():
                l_norm = torch.max(torch.abs(img - img_ori)).item()
                print('Perturbation: %f' % l_norm)
                mean_p += l_norm
                source_pred = count_spikes(source_net, img).unsqueeze(0).max(1)[1]
                target_pred = count_spikes(target_net, img).unsqueeze(0).max(1)[1]
                source_net.reset_()
                target_net.reset_()

            source_flag = attack_flag(source_pred, goal)
            target_flag = attack_flag(target_pred, goal)
            source_success_sum += source_flag
            target_success_sum += target_flag
            print('Source Attack Success' if source_flag > 0.5
                  else 'Source Attack Failure')
            print('Target Attack Success' if target_flag > 0.5
                  else 'Target Attack Failure')

        if test_sum >= max_attacks:
            break

    # Average over the attacks actually run; max() guards an empty loader.
    n_attacks = max(test_sum, 1)
    mean_p /= n_attacks
    print('Mean Perturbation: %.3f' % mean_p)
    print('source_success_sum: %d' % source_success_sum)
    print('target_success_sum: %d' % target_success_sum)
    print('test_sum: %d' % test_sum)
    print('source_success_rate: %.2f%%' % (100 * source_success_sum / n_attacks))
    print('target_success_rate: %.2f%%' % (100 * target_success_sum / n_attacks))
def main():
    """Run a BIM adversarial attack against an MNIST spiking network.

    The Poisson encoder is not differentiable, so the loss gradient is taken
    w.r.t. every encoded spike frame and summed over the ``T`` time steps as
    a surrogate for the image gradient; each iteration moves the image by
    ``eta`` along the sign of that sum.

    ``attack_type == 'UT'`` attacks the first 250 test images untargeted; any
    other value attacks every other class of each test image until 270
    attacks have been made. Prints the per-sample L2 perturbation and
    outcome, then the mean perturbation and the overall success rate.
    """
    # Run configuration (hard-coded; the former interactive prompts were removed).
    device = 'cuda:1'
    dataset_dir = '../../dataset/'
    class_num = 10
    T = 50  # number of Poisson-encoded simulation time steps
    phase = 'BIM'

    torch.cuda.empty_cache()
    encoder = encoding.PoissonEncoder()

    if phase != 'BIM':
        return

    model_path = './models/mnist_spike_v1.pth'
    iter_num = 100  # BIM iterations per attacked sample
    eta = 0.02  # size of one sign-gradient step
    attack_type = 'T'  # 'UT' = untargeted, anything else = targeted
    # True: clip_by_tensor to [img_ori - clip_eps, img_ori + clip_eps];
    # False: only clamp to [0, 1].
    clip_flag = True
    clip_eps = 0.4

    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root=dataset_dir, train=False,
            transform=torchvision.transforms.ToTensor(), download=True),
        batch_size=1, shuffle=False, drop_last=False)

    net = Net().to(device)
    # map_location makes loading independent of the device the checkpoint
    # was saved from.
    net.load_state_dict(torch.load(model_path, map_location=device))

    targeted = attack_type != 'UT'
    max_attacks = 270 if targeted else 250

    def spike_grad_sign(img, goal):
        """Return the sign of d(loss)/d(spikes) summed over all T frames.

        ``goal`` is the true label (untargeted) or the target label
        (targeted). Resets ``net`` and clears its parameter gradients
        afterwards.
        NOTE(review): the per-step ``unsqueeze(0)`` assumes ``Net`` returns a
        1-D score vector for a single image -- confirm against ``Net``.
        """
        spike_train = []
        out_spikes_counter = None
        for _ in range(T):
            spike = encoder(img).float()
            spike.requires_grad = True
            spike_train.append(spike)
            step_out = net(spike).unsqueeze(0)
            if out_spikes_counter is None:
                out_spikes_counter = step_out
            else:
                out_spikes_counter += step_out
        loss = F.cross_entropy(out_spikes_counter / T, goal)
        loss.backward()

        rate = torch.zeros_like(spike_train[0])
        for spike in spike_train:
            rate += spike.grad
        net.reset_()
        # zero_grad() also copes with parameters whose .grad is None.
        net.zero_grad()
        return torch.sign(rate)

    def count_spikes(img):
        """Sum ``net``'s outputs over T fresh Poisson encodings of ``img``.

        ``net`` is not reset here; callers reset it afterwards.
        """
        out_spikes_counter = net(encoder(img).float())
        for _ in range(T - 1):
            out_spikes_counter += net(encoder(img).float())
        return out_spikes_counter

    def attack_flag(pred, goal):
        """Count predictions meeting the attack goal (0.0 or 1.0 per image)."""
        reached = pred == goal if targeted else pred != goal
        return reached.float().sum().item()

    mean_p = 0.0
    test_sum = 0
    success_sum = 0
    for X, y in test_data_loader:
        label = y.to(device)
        if targeted:
            # One attack towards every class other than the true one.
            goals = [(label + i) % class_num for i in range(1, class_num)]
        else:
            goals = [label]

        for goal in goals:
            img = X.to(device)
            img_ori = img.clone()
            test_sum += 1
            if not targeted:
                print('Img %d' % test_sum)

            net.train()
            for _ in range(iter_num):
                img_grad = spike_grad_sign(img, goal)
                # Untargeted: climb the true-label loss; targeted: descend
                # the target-label loss.
                step = -eta * img_grad if targeted else eta * img_grad
                if clip_flag:
                    img = clip_by_tensor(img + step, img_ori - clip_eps,
                                         img_ori + clip_eps)
                else:
                    img = torch.clamp(img + step, 0.0, 1.0)

            net.eval()
            with torch.no_grad():
                img_diff = img - img_ori
                l2_norm = torch.norm(img_diff.view(img_diff.size(0), -1),
                                     dim=1).item()
                print('Total Perturbation: %f' % l2_norm)
                mean_p += l2_norm
                pred = count_spikes(img).unsqueeze(0).max(1)[1]
                # Reset so this evaluation pass does not leak neuron state
                # into the next attack (previously missing).
                net.reset_()

            flag = attack_flag(pred, goal)
            success_sum += flag
            print('Attack Success' if flag > 0.5 else 'Attack Failure')

        if test_sum >= max_attacks:
            break

    # Average over the attacks actually run; max() guards an empty loader.
    n_attacks = max(test_sum, 1)
    mean_p /= n_attacks
    print('Mean Perturbation: %.2f' % mean_p)
    print('success_sum: %d' % success_sum)
    print('test_sum: %d' % test_sum)
    print('success_rate: %.2f%%' % (100 * success_sum / n_attacks))
def main():
    """Train a CIFAR-10 spiking network or attack it with BIM.

    ``phase == 'train'`` trains ``Net`` on Poisson-encoded CIFAR-10 with an
    MSE loss between output firing rates and one-hot labels, logs train/test
    accuracy to TensorBoard and keeps only the checkpoint of the best test
    epoch in ``model_dir``.

    ``phase == 'BIM'`` runs a sign-gradient (BIM) attack in normalized pixel
    space. The gradient is taken w.r.t. every encoded spike frame and summed
    over ``T`` steps as a surrogate for the image gradient. 'UT' attacks the
    first 250 test images untargeted; any other attack type attacks every
    other class of each test image until 270 attacks have been made.
    """
    # Run configuration (hard-coded; the former interactive prompts were removed).
    device = 'cuda:3'
    dataset_dir = '../../dataset/'
    class_num = 10
    lr = 1e-4
    T = 8  # number of Poisson-encoded simulation time steps
    phase = 'train'

    torch.cuda.empty_cache()
    encoder = encoding.PoissonEncoder()

    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])

    def count_spikes(net, img):
        """Sum ``net``'s outputs over T fresh Poisson encodings of ``img``.

        Every step is accumulated (the old test loop overwrote the counter
        and so scored only the last step). ``net`` is not reset here;
        callers reset it afterwards.
        """
        out_spikes_counter = net(encoder(img).float())
        for _ in range(T - 1):
            out_spikes_counter += net(encoder(img).float())
        return out_spikes_counter

    def run_training():
        """Train ``Net`` and keep only the best test-accuracy checkpoint."""
        model_dir = './models/'
        batch_size = 64
        train_epoch = 9999999
        log_dir = './logs/'

        writer = SummaryWriter(log_dir)
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std),
        ])
        train_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(
                root=dataset_dir, train=True, transform=transform_train,
                download=True),
            batch_size=batch_size, shuffle=True, drop_last=True)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(
                root=dataset_dir, train=False, transform=transform_test,
                download=True),
            batch_size=batch_size, shuffle=True, drop_last=False)

        net = Net().to(device)
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
        # torch.save fails if the target directory does not exist.
        os.makedirs(model_dir, exist_ok=True)

        train_times = 0
        best_epoch = 0
        max_correct_sum = 0
        for epoch in range(1, train_epoch + 1):
            net.train()
            for X, y in train_data_loader:
                img, label = X.to(device), y.to(device)
                optimizer.zero_grad()
                out_spikes_counter_frequency = count_spikes(net, img) / T
                loss = F.mse_loss(out_spikes_counter_frequency,
                                  F.one_hot(label, class_num).float())
                loss.backward()
                optimizer.step()
                net.reset_()

                correct_rate = (out_spikes_counter_frequency.max(1)[1]
                                == label).float().mean().item()
                writer.add_scalar('train_correct_rate', correct_rate,
                                  train_times)
                train_times += 1

            net.eval()
            with torch.no_grad():
                test_sum = 0
                correct_sum = 0
                for X, y in test_data_loader:
                    img, label = X.to(device), y.to(device)
                    out_spikes_counter = count_spikes(net, img)
                    correct_sum += (out_spikes_counter.max(1)[1]
                                    == label).float().sum().item()
                    test_sum += label.numel()
                    net.reset_()
            test_correct_rate = correct_sum / test_sum
            writer.add_scalar('test_correct_rate', test_correct_rate,
                              train_times)
            print('epoch', epoch, 'test_correct_rate', test_correct_rate)

            if correct_sum > max_correct_sum:
                max_correct_sum = correct_sum
                torch.save(net.state_dict(),
                           model_dir + 'spike_best_%d.pth' % epoch)
                if best_epoch > 0:
                    # Remove the previous best checkpoint without a shell
                    # command; a missing file is fine (as with `rm`).
                    try:
                        os.remove(model_dir + 'spike_best_%d.pth' % best_epoch)
                    except FileNotFoundError:
                        pass
                best_epoch = epoch
        writer.close()

    def run_bim():
        """Attack a trained ``Net`` with BIM and print success statistics."""
        model_path = './models/cifar10_spike_v1.pth'
        iter_num = 50  # BIM iterations per attacked sample
        eta = 0.03  # size of one sign-gradient step
        attack_type = 'UT'  # 'UT' = untargeted, anything else = targeted
        clip_eps = 0.6  # L-inf radius around the original (normalized) image

        test_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(
                root=dataset_dir, train=False, transform=transform_test,
                download=True),
            batch_size=1, shuffle=False, drop_last=False)

        # Normalized all-ones / all-zeros images: the valid pixel range.
        # ToTensor keeps the float64 dtype of these ndarrays, so cast to
        # float32 to stop clipping from promoting the image to float64.
        p_max = transform_test(np.ones((32, 32, 3))).float().to(device)
        p_min = transform_test(np.zeros((32, 32, 3))).float().to(device)

        net = Net().to(device)
        net.load_state_dict(torch.load(model_path, map_location=device))

        targeted = attack_type != 'UT'
        max_attacks = 270 if targeted else 250

        def spike_grad_sign(img, goal):
            """Return the sign of d(loss)/d(spikes) summed over all T frames.

            Resets ``net`` and clears its parameter gradients afterwards.
            NOTE(review): the per-step ``unsqueeze(0)`` assumes ``Net``
            returns a 1-D score vector for a single image -- confirm.
            """
            spike_train = []
            out_spikes_counter = None
            for _ in range(T):
                spike = encoder(img).float()
                spike.requires_grad = True
                spike_train.append(spike)
                step_out = net(spike).unsqueeze(0)
                if out_spikes_counter is None:
                    out_spikes_counter = step_out
                else:
                    out_spikes_counter += step_out
            loss = F.cross_entropy(out_spikes_counter / T, goal)
            loss.backward()

            rate = torch.zeros_like(spike_train[0])
            for spike in spike_train:
                rate += spike.grad
            net.reset_()
            # zero_grad() also copes with parameters whose .grad is None.
            net.zero_grad()
            return torch.sign(rate)

        mean_p = 0.0
        test_sum = 0
        success_sum = 0
        for X, y in test_data_loader:
            label = y.to(device)
            if targeted:
                # One attack towards every class other than the true one.
                goals = [(label + i) % class_num for i in range(1, class_num)]
            else:
                goals = [label]

            for goal in goals:
                img = X.to(device)
                img_ori = img.clone()
                test_sum += 1
                if not targeted:
                    print('Img %d' % test_sum)

                net.train()
                for _ in range(iter_num):
                    img_grad = spike_grad_sign(img, goal)
                    # Untargeted: climb the true-label loss; targeted:
                    # descend the target-label loss.
                    step = -eta * img_grad if targeted else eta * img_grad
                    # Feed the clipped adversarial image back into the next
                    # iteration (previously it was discarded).
                    img = clip_by_tensor(img + step, img_ori - clip_eps,
                                         img_ori + clip_eps, p_min, p_max)

                net.eval()
                with torch.no_grad():
                    l_norm = torch.max(torch.abs(img - img_ori)).item()
                    print('Perturbation: %f' % l_norm)
                    mean_p += l_norm
                    pred = count_spikes(net, img).unsqueeze(0).max(1)[1]
                    # Reset so this evaluation pass does not leak neuron
                    # state into the next attack (previously missing).
                    net.reset_()

                reached = pred == goal if targeted else pred != goal
                attack_flag = reached.float().sum().item()
                success_sum += attack_flag
                print('Attack Success' if attack_flag > 0.5
                      else 'Attack Failure')

            if test_sum >= max_attacks:
                break

        # Average over the attacks actually run; max() guards an empty loader.
        n_attacks = max(test_sum, 1)
        mean_p /= n_attacks
        print('Mean Perturbation: %.3f' % mean_p)
        print('success_sum: %d' % success_sum)
        print('test_sum: %d' % test_sum)
        print('success_rate: %.2f%%' % (100 * success_sum / n_attacks))

    if phase == 'train':
        run_training()
    elif phase == 'BIM':
        run_bim()
def main():
    """Train an MNIST spiking network or attack it with the G2S/GT method.

    ``phase == 'train'`` trains ``Net`` on Poisson-encoded MNIST with an MSE
    loss between output firing rates and one-hot labels, logs train/test
    accuracy to TensorBoard and keeps only the checkpoint of the best test
    epoch in ``model_dir``.

    ``phase == 'BIM'`` runs an iterative attack on the spike trains. For every
    encoded frame, a non-vanishing gradient is turned into spike flips by the
    gradient-to-spike (G2S) converter; a vanishing gradient falls back to the
    gradient trigger (GT), which flips spikes at random with probability
    ``gamma``. The per-frame flips are averaged over ``T`` and added to the
    image only when their L2 norm is below ``perturbation``. 'UT' attacks the
    first 250 test images untargeted; any other attack type attacks every
    other class of each test image until 270 attacks have been made.
    """
    # Run configuration (hard-coded; the former interactive prompts were removed).
    device = 'cuda:1'
    dataset_dir = '../../dataset/'
    class_num = 10
    lr = 1e-4
    T = 50  # number of Poisson-encoded simulation time steps
    phase = 'BIM'

    torch.cuda.empty_cache()
    encoder = encoding.PoissonEncoder()

    def mnist_loader(train, batch_size, shuffle, drop_last):
        """Build a DataLoader over MNIST converted with ToTensor."""
        return torch.utils.data.DataLoader(
            dataset=torchvision.datasets.MNIST(
                root=dataset_dir, train=train,
                transform=torchvision.transforms.ToTensor(), download=True),
            batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

    def count_spikes(net, img):
        """Sum ``net``'s outputs over T fresh Poisson encodings of ``img``.

        Every step is accumulated (the old test loop overwrote the counter
        and so scored only the last step). ``net`` is not reset here;
        callers reset it afterwards.
        """
        out_spikes_counter = net(encoder(img).float())
        for _ in range(T - 1):
            out_spikes_counter += net(encoder(img).float())
        return out_spikes_counter

    def run_training():
        """Train ``Net`` and keep only the best test-accuracy checkpoint."""
        model_dir = './models/'
        batch_size = 64
        train_epoch = 9999999
        log_dir = './logs/'

        writer = SummaryWriter(log_dir)
        train_data_loader = mnist_loader(True, batch_size, shuffle=True,
                                         drop_last=True)
        test_data_loader = mnist_loader(False, batch_size, shuffle=True,
                                        drop_last=False)

        net = Net().to(device)
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
        # torch.save fails if the target directory does not exist.
        os.makedirs(model_dir, exist_ok=True)

        train_times = 0
        best_epoch = 0
        max_correct_sum = 0
        for epoch in range(1, train_epoch + 1):
            net.train()
            for X, y in train_data_loader:
                img, label = X.to(device), y.to(device)
                optimizer.zero_grad()
                out_spikes_counter_frequency = count_spikes(net, img) / T
                loss = F.mse_loss(out_spikes_counter_frequency,
                                  F.one_hot(label, class_num).float())
                loss.backward()
                optimizer.step()
                net.reset_()

                correct_rate = (out_spikes_counter_frequency.max(1)[1]
                                == label).float().mean().item()
                writer.add_scalar('train_correct_rate', correct_rate,
                                  train_times)
                train_times += 1

            net.eval()
            with torch.no_grad():
                test_sum = 0
                correct_sum = 0
                for X, y in test_data_loader:
                    img, label = X.to(device), y.to(device)
                    out_spikes_counter = count_spikes(net, img)
                    correct_sum += (out_spikes_counter.max(1)[1]
                                    == label).float().sum().item()
                    test_sum += label.numel()
                    net.reset_()
            test_correct_rate = correct_sum / test_sum
            writer.add_scalar('test_correct_rate', test_correct_rate,
                              train_times)
            print('epoch', epoch, 'test_correct_rate', test_correct_rate)

            if correct_sum > max_correct_sum:
                max_correct_sum = correct_sum
                torch.save(net.state_dict(),
                           model_dir + 'spike_best_%d.pth' % epoch)
                if best_epoch > 0:
                    # Remove the previous best checkpoint without a shell
                    # command; a missing file is fine (as with `rm`).
                    try:
                        os.remove(model_dir + 'spike_best_%d.pth' % best_epoch)
                    except FileNotFoundError:
                        pass
                best_epoch = epoch
        writer.close()

    def run_bim():
        """Attack a trained ``Net`` with G2S/GT and print success statistics."""
        model_path = './models/mnist_spike_v1.pth'
        gamma = 0.05  # GT flip probability for frames with vanishing gradient
        iter_num = 50  # attack iterations per sample
        perturbation = 3.1  # L2 budget for the averaged flips of one iteration
        attack_type = 'T'  # 'UT' = untargeted, anything else = targeted

        test_data_loader = mnist_loader(False, 1, shuffle=False,
                                        drop_last=False)
        net = Net().to(device)
        # map_location makes loading independent of the device the
        # checkpoint was saved from.
        net.load_state_dict(torch.load(model_path, map_location=device))

        targeted = attack_type != 'UT'
        max_attacks = 270 if targeted else 250

        def g2s_perturbation(img, goal):
            """Return the spike-flip perturbation averaged over all T frames.

            Untargeted attacks maximise the cross-entropy of the true label;
            targeted attacks minimise the MSE to the one-hot target (hence
            the negated gradient sign). Resets ``net`` and clears its
            parameter gradients afterwards.
            NOTE(review): the per-step ``unsqueeze(0)`` assumes ``Net``
            returns a 1-D score vector for a single image -- confirm.
            """
            spike_train = []
            out_spikes_counter = None
            for _ in range(T):
                spike = encoder(img).float()
                spike.requires_grad = True
                spike_train.append(spike)
                step_out = net(spike).unsqueeze(0)
                if out_spikes_counter is None:
                    out_spikes_counter = step_out
                else:
                    out_spikes_counter += step_out
            out_spikes_counter_frequency = out_spikes_counter / T
            if targeted:
                loss = F.mse_loss(out_spikes_counter_frequency,
                                  F.one_hot(goal, class_num).float())
            else:
                loss = F.cross_entropy(out_spikes_counter_frequency, goal)
            loss.backward()

            ik = torch.zeros_like(spike_train[0])
            # Build the perturbation outside autograd so no graph through the
            # spikes is chained into ``img`` across iterations.
            with torch.no_grad():
                for spike in spike_train:
                    grad = spike.grad
                    if torch.max(torch.abs(grad)) > 1e-32:
                        # G2S converter: flip each spike with probability
                        # given by its min-max normalised gradient magnitude,
                        # in the direction of the (signed) gradient.
                        grad_sign = (-torch.sign(grad) if targeted
                                     else torch.sign(grad))
                        grad_abs = torch.abs(grad)
                        grad_min = torch.min(grad_abs)
                        grad_range = torch.max(grad_abs) - grad_min
                        if grad_range > 0:
                            grad_norm = (grad_abs - grad_min) / grad_range
                        else:
                            # All magnitudes equal: the old 0/0 gave NaN,
                            # which torch.bernoulli rejects.
                            grad_norm = torch.ones_like(grad_abs)
                        grad_mask = torch.bernoulli(grad_norm)
                        g2s = grad_sign * grad_mask
                        ik += torch.clamp(g2s + spike, 0.0, 1.0) - spike
                    else:
                        # Gradient trigger: the gradient vanished, so flip
                        # spikes at random with probability gamma.
                        gt = torch.bernoulli(torch.ones_like(grad) * gamma)
                        ik += (gt.bool() ^ spike.bool()).float() - spike
                ik /= T

            net.reset_()
            # zero_grad() also copes with parameters whose .grad is None.
            net.zero_grad()
            return ik

        mean_p = 0.0
        test_sum = 0
        success_sum = 0
        for X, y in test_data_loader:
            label = y.to(device)
            if targeted:
                # One attack towards every class other than the true one.
                goals = [(label + i) % class_num for i in range(1, class_num)]
            else:
                goals = [label]

            for goal in goals:
                img = X.to(device)
                img_ori = img.clone()
                test_sum += 1
                if not targeted:
                    print('Img %d' % test_sum)

                net.train()
                for _ in range(iter_num):
                    ik = g2s_perturbation(img, goal)
                    step_norm = torch.norm(ik.view(ik.size(0), -1),
                                           dim=1).item()
                    # Apply only steps that stay within the L2 budget.
                    if step_norm < perturbation:
                        img = torch.clamp(img + ik, 0.0, 1.0)

                net.eval()
                with torch.no_grad():
                    img_diff = img - img_ori
                    l2_norm = torch.norm(img_diff.view(img_diff.size(0), -1),
                                         dim=1).item()
                    print('Total Perturbation: %f' % l2_norm)
                    mean_p += l2_norm
                    pred = count_spikes(net, img).unsqueeze(0).max(1)[1]
                    # Reset so this evaluation pass does not leak neuron
                    # state into the next attack (previously missing).
                    net.reset_()

                reached = pred == goal if targeted else pred != goal
                attack_flag = reached.float().sum().item()
                success_sum += attack_flag
                print('Attack Success' if attack_flag > 0.5
                      else 'Attack Failure')

            if test_sum >= max_attacks:
                break

        # Average over the attacks actually run; max() guards an empty loader.
        n_attacks = max(test_sum, 1)
        mean_p /= n_attacks
        print('Mean Perturbation: %.2f' % mean_p)
        print('success_sum: %d' % success_sum)
        print('test_sum: %d' % test_sum)
        print('success_rate: %.2f%%' % (100 * success_sum / n_attacks))

    if phase == 'train':
        run_training()
    elif phase == 'BIM':
        run_bim()
def main():
    """Craft BIM adversarial MNIST images on an SNN and measure their transferability.

    Adversarial examples are crafted with the Basic Iterative Method (repeated
    signed input-gradient steps) against ``source_net``, an ``SNN_Net`` that is
    fed the raw image at each of the ``T`` time steps. Every adversarial image
    is then classified by the source model, by ``target_net1`` (an ``ANN_Net``
    fed the raw image once) and by ``target_net2`` (an ``SNN_Net`` fed a fresh
    Poisson spike sample at every time step).

    ``attack_type == 'UT'`` (untargeted): the first 250 test images are
    attacked by ascending the cross-entropy of the true label; an attack
    succeeds when the prediction differs from the true label.
    Otherwise (targeted): each image is attacked towards each of the other
    ``class_num - 1`` labels by descending the cross-entropy of the target
    label, until 270 (image, target) pairs have been processed; an attack
    succeeds when the prediction equals the target label.

    Per-attack results and the final success counts/rates are printed.
    """
    # Run configuration (these used to be read interactively with input()).
    device = 'cuda:3'               # device to run on, e.g. 'cpu' or 'cuda:0'
    dataset_dir = '../../dataset/'  # where MNIST is stored / downloaded to
    class_num = 10
    T = 50                          # simulation length in time steps
    phase = 'BIM'

    torch.cuda.empty_cache()
    encoder = encoding.PoissonEncoder()

    if phase != 'BIM':
        return

    source_model_path = './models/mnist_img_v1.pth'
    target_model_path1 = './models/mnist_ann_v1.pth'
    target_model_path2 = './models/mnist_spike_v1.pth'
    iter_num = 25      # BIM iterations per attack
    eta = 0.02         # size of each signed-gradient step
    attack_type = 'T'  # 'UT' = untargeted, anything else = targeted
    clip_flag = True   # True: clip_by_tensor to the clip_eps box; False: clamp to [0, 1]
    clip_eps = 0.3

    targeted = attack_type != 'UT'
    # Untargeted runs attack 250 images; targeted runs attack 270
    # (image, target) pairs, i.e. 30 images x 9 targets.
    max_attacks = 270 if targeted else 250
    # Targeted attacks descend the loss of the target label; untargeted
    # attacks ascend the loss of the true label.
    step_sign = -1 if targeted else 1

    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=1,
        shuffle=False,
        drop_last=False)

    # map_location lets checkpoints saved on another device load directly here.
    source_net = SNN_Net().to(device)
    source_net.load_state_dict(torch.load(source_model_path, map_location=device))
    target_net1 = ANN_Net().to(device)
    target_net1.load_state_dict(torch.load(target_model_path1, map_location=device))
    target_net2 = SNN_Net().to(device)
    target_net2.load_state_dict(torch.load(target_model_path2, map_location=device))
    target_net1.eval()
    target_net2.eval()

    def attack_jobs():
        """Yield (image, true label, goal label) for every attack to run."""
        for X, y in test_data_loader:
            img, label = X.to(device), y.to(device)
            if targeted:
                for i in range(1, class_num):
                    yield img, label, (label + i) % class_num
            else:
                yield img, label, label

    def craft(img, goal):
        """Return (clean copy of img, adversarial img) after iter_num BIM steps."""
        img_ori = img.clone()
        img = img.clone().requires_grad_(True)
        source_net.train()
        for _ in range(iter_num):
            for t in range(T):
                if t == 0:
                    out_spikes_counter = source_net(img).unsqueeze(0)
                else:
                    out_spikes_counter += source_net(img).unsqueeze(0)
            loss = F.cross_entropy(out_spikes_counter / T, goal)
            # Only the input gradient is needed; unlike loss.backward() this
            # does not keep accumulating never-cleared parameter gradients.
            img_grad = torch.sign(torch.autograd.grad(loss, img)[0])
            with torch.no_grad():
                stepped = img + step_sign * eta * img_grad
                if clip_flag:
                    img_adv = clip_by_tensor(stepped, img_ori - clip_eps, img_ori + clip_eps)
                else:
                    img_adv = torch.clamp(stepped, 0.0, 1.0)
            # Start the next step from a fresh leaf tensor.
            img = img_adv.detach().requires_grad_(True)
            # SNN neurons keep state between forward passes; clear it before
            # the next T-step simulation.
            source_net.reset_()
        return img_ori, img.detach()

    def evaluate(img):
        """Return (source, ANN-target, SNN-target) class scores for img."""
        source_net.eval()
        with torch.no_grad():
            target_output1 = target_net1(img).unsqueeze(0)
            for t in range(T):
                source_step = source_net(img).unsqueeze(0)
                target_step2 = target_net2(encoder(img).float()).unsqueeze(0)
                if t == 0:
                    source_counter, target_counter2 = source_step, target_step2
                else:
                    source_counter += source_step
                    target_counter2 += target_step2
        source_net.reset_()
        target_net2.reset_()
        return source_counter / T, target_output1, target_counter2 / T

    def succeeded(output, label, goal):
        """Return the number (0 or 1 at batch size 1) of top-1 predictions meeting the goal."""
        pred = output.max(1)[1]
        hit = (pred == goal) if targeted else (pred != label)
        return hit.float().sum().item()

    def report(name, flag):
        print('%s %s' % (name, 'Success' if flag > 0.5 else 'Failure'))

    mean_p = 0.0
    test_sum = 0
    source_success_sum = 0
    target_success_sum1 = 0
    target_success_sum2 = 0

    for img, label, goal in attack_jobs():
        test_sum += 1
        if not targeted:
            print('Img %d' % test_sum)

        img_ori, img_adv = craft(img, goal)
        img_diff = img_adv - img_ori
        l2_norm = torch.norm(img_diff.view(img_diff.size()[0], -1), dim=1).item()
        print('Perturbation: %f' % l2_norm)
        mean_p += l2_norm

        source_output, target_output1, target_output2 = evaluate(img_adv)
        source_attack_flag = succeeded(source_output, label, goal)
        target_attack_flag1 = succeeded(target_output1, label, goal)
        target_attack_flag2 = succeeded(target_output2, label, goal)
        source_success_sum += source_attack_flag
        target_success_sum1 += target_attack_flag1
        target_success_sum2 += target_attack_flag2
        report('Source Attack', source_attack_flag)
        report('Target Attack 1', target_attack_flag1)
        report('Target Attack 2', target_attack_flag2)

        # One flat loop, so this cut-off ends the whole run; with nested
        # image/target loops a break only leaves the inner loop.
        if test_sum >= max_attacks:
            break

    # Average over the attacks actually run (== max_attacks unless the test
    # set ran out first); max() guards against an empty loader.
    n = max(test_sum, 1)
    mean_p /= n
    print('Mean Perturbation: %.2f' % mean_p)
    print('source_success_sum: %d' % source_success_sum)
    print('target_success_1_sum: %d' % target_success_sum1)
    print('target_success_2_sum: %d' % target_success_sum2)
    print('test_sum: %d' % test_sum)
    print('source_success_rate: %.2f%%' % (100 * source_success_sum / n))
    print('target_success_1_rate: %.2f%%' % (100 * target_success_sum1 / n))
    print('target_success_2_rate: %.2f%%' % (100 * target_success_sum2 / n))
def main():
    """Craft BIM adversarial CIFAR-10 images on an ANN and test them on two SNNs.

    Adversarial examples are crafted with the Basic Iterative Method (repeated
    signed input-gradient steps) against ``source_net`` (an ``ANN_Net``). Each
    step is passed through ``clip_by_tensor`` with the ``clip_eps`` box around
    the normalised clean image and with ``p_min``/``p_max``. Every adversarial
    image is then classified by the source model, by ``target_net1`` (an
    ``SNN_Net`` fed the normalised image at each of the ``T`` time steps) and by
    ``target_net2`` (an ``SNN_Net`` fed a fresh Poisson spike sample per step).
    The reported perturbation is the L-infinity norm of the image change.

    ``attack_type == 'UT'`` (untargeted): the first 250 test images are
    attacked by ascending the cross-entropy of the true label; success means
    a wrong prediction. Otherwise (targeted): each image is attacked towards
    each other label by descending the cross-entropy of the target label until
    270 (image, target) pairs have been processed; success means the target
    label is predicted.
    """
    # Run configuration (these used to be read interactively with input()).
    device = 'cuda:3'               # device to run on, e.g. 'cpu' or 'cuda:0'
    dataset_dir = '../../dataset/'  # where CIFAR-10 is stored / downloaded to
    class_num = 10
    T = 50                          # simulation length in time steps
    phase = 'BIM'

    torch.cuda.empty_cache()
    encoder = encoding.PoissonEncoder()

    if phase != 'BIM':
        return

    source_model_path = './models/cifar10_ann_v1.pth'
    target_model_path1 = './models/cifar10_img_v1.pth'
    target_model_path2 = './models/cifar10_spike_v1.pth'
    iter_num = 25      # BIM iterations per attack
    eta = 0.003        # size of each signed-gradient step
    attack_type = 'UT'  # 'UT' = untargeted, anything else = targeted
    clip_eps = 0.06

    targeted = attack_type != 'UT'
    # Untargeted runs attack 250 images; targeted runs attack 270
    # (image, target) pairs, i.e. 30 images x 9 targets.
    max_attacks = 270 if targeted else 250
    # Untargeted attacks must ASCEND the loss of the true label (the previous
    # code subtracted the gradient here, reinforcing the correct class);
    # targeted attacks descend the loss of the target label.
    step_sign = -1 if targeted else 1

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(root=dataset_dir,
                                             train=False,
                                             transform=transform_test,
                                             download=True),
        batch_size=1,
        shuffle=False,
        drop_last=False)
    # All-one / all-zero images mapped through the test transform, handed to
    # clip_by_tensor as pixel bounds. float32 arrays keep these tensors float32
    # like the images (ToTensor leaves float64 arrays as float64).
    p_max = transform_test(np.ones((32, 32, 3), dtype=np.float32)).to(device)
    p_min = transform_test(np.zeros((32, 32, 3), dtype=np.float32)).to(device)

    # map_location lets checkpoints saved on another device load directly here.
    source_net = ANN_Net().to(device)
    source_net.load_state_dict(torch.load(source_model_path, map_location=device))
    target_net1 = SNN_Net().to(device)
    target_net1.load_state_dict(torch.load(target_model_path1, map_location=device))
    target_net2 = SNN_Net().to(device)
    target_net2.load_state_dict(torch.load(target_model_path2, map_location=device))
    target_net1.eval()
    target_net2.eval()

    def attack_jobs():
        """Yield (image, true label, goal label) for every attack to run."""
        for X, y in test_data_loader:
            img, label = X.to(device), y.to(device)
            if targeted:
                for i in range(1, class_num):
                    yield img, label, (label + i) % class_num
            else:
                yield img, label, label

    def craft(img, goal):
        """Return (clean copy of img, adversarial img) after iter_num BIM steps."""
        img_ori = img.clone()
        img = img.clone().requires_grad_(True)
        # NOTE(review): the ANN is kept in train mode while crafting, as in the
        # original script; if ANN_Net has dropout or batch norm this affects
        # the gradients — confirm this is intended.
        source_net.train()
        for _ in range(iter_num):
            output = source_net(img).unsqueeze(0)
            loss = F.cross_entropy(output, goal)
            # Only the input gradient is needed; unlike loss.backward() this
            # does not keep accumulating never-cleared parameter gradients.
            img_grad = torch.sign(torch.autograd.grad(loss, img)[0])
            with torch.no_grad():
                img_adv = clip_by_tensor(img + step_sign * eta * img_grad,
                                         img_ori - clip_eps, img_ori + clip_eps,
                                         p_min, p_max)
            # Start the next step from a fresh leaf tensor.
            img = img_adv.detach().requires_grad_(True)
        return img_ori, img.detach()

    def evaluate(img):
        """Return (source, SNN-img-target, SNN-spike-target) class scores for img."""
        source_net.eval()
        with torch.no_grad():
            source_output = source_net(img).unsqueeze(0)
            for t in range(T):
                target_step1 = target_net1(img).unsqueeze(0)
                target_step2 = target_net2(encoder(img).float()).unsqueeze(0)
                if t == 0:
                    target_counter1, target_counter2 = target_step1, target_step2
                else:
                    target_counter1 += target_step1
                    target_counter2 += target_step2
        # SNN neurons keep state between forward passes; clear it.
        target_net1.reset_()
        target_net2.reset_()
        return source_output, target_counter1 / T, target_counter2 / T

    def succeeded(output, label, goal):
        """Return the number (0 or 1 at batch size 1) of top-1 predictions meeting the goal."""
        pred = output.max(1)[1]
        hit = (pred == goal) if targeted else (pred != label)
        return hit.float().sum().item()

    def report(name, flag):
        print('%s %s' % (name, 'Success' if flag > 0.5 else 'Failure'))

    mean_p = 0.0
    test_sum = 0
    source_success_sum = 0
    target_success_sum1 = 0
    target_success_sum2 = 0

    for img, label, goal in attack_jobs():
        test_sum += 1
        if not targeted:
            print('Img %d' % test_sum)

        img_ori, img_adv = craft(img, goal)
        l_norm = torch.max(torch.abs(img_adv - img_ori)).item()
        print('Perturbation: %f' % l_norm)
        mean_p += l_norm

        source_output, target_output1, target_output2 = evaluate(img_adv)
        source_attack_flag = succeeded(source_output, label, goal)
        target_attack_flag1 = succeeded(target_output1, label, goal)
        target_attack_flag2 = succeeded(target_output2, label, goal)
        source_success_sum += source_attack_flag
        target_success_sum1 += target_attack_flag1
        target_success_sum2 += target_attack_flag2
        report('Source Attack', source_attack_flag)
        report('Target Attack 1', target_attack_flag1)
        report('Target Attack 2', target_attack_flag2)

        # One flat loop, so this cut-off ends the whole run; with nested
        # image/target loops a break only leaves the inner loop.
        if test_sum >= max_attacks:
            break

    # Average over the attacks actually run (== max_attacks unless the test
    # set ran out first); max() guards against an empty loader.
    n = max(test_sum, 1)
    mean_p /= n
    print('Mean Perturbation: %.3f' % mean_p)
    print('source_success_sum: %d' % source_success_sum)
    print('target_success_1_sum: %d' % target_success_sum1)
    print('target_success_2_sum: %d' % target_success_sum2)
    print('test_sum: %d' % test_sum)
    print('source_success_rate: %.2f%%' % (100 * source_success_sum / n))
    print('target_success_1_rate: %.2f%%' % (100 * target_success_sum1 / n))
    print('target_success_2_rate: %.2f%%' % (100 * target_success_sum2 / n))
def main():
    """Train a multi-GPU spiking network on CIFAR-10 from Poisson-encoded images.

    Hyper-parameters are read interactively. The training accuracy of every
    batch and the test accuracy after every epoch are written to TensorBoard
    in ``log_dir``; the writer is always closed so pending events are flushed.
    """
    gpu_list = input('输入使用的5个gpu,例如“0,1,2,0,3” ').split(',')
    dataset_dir = input('输入保存CIFAR10数据集的位置,例如“./” ')
    batch_size = int(input('输入batch_size,例如“64” '))
    split_sizes = int(input('输入split_sizes,例如“16” '))
    learning_rate = float(input('输入学习率,例如“1e-3” '))
    T = int(input('输入仿真时长,例如“50” '))
    tau = float(input('输入LIF神经元的时间常数tau,例如“100.0” '))
    train_epoch = int(input('输入训练轮数,即遍历训练集的次数,例如“100” '))
    log_dir = input('输入保存tensorboard日志文件的位置,例如“./” ')

    writer = SummaryWriter(log_dir)

    # Data loaders.
    train_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(
            root=dataset_dir,
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    # The test set only feeds an order-independent accuracy, so no shuffling.
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=batch_size,
        shuffle=False,
        drop_last=False)

    # Network built over the GPUs listed in gpu_list.
    net = Net(gpu_list=gpu_list, tau=tau)
    # Adam optimizer.
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    # Poisson encoder turning pixel intensities into spikes.
    encoder = encoding.PoissonEncoder()

    def simulate(img):
        """Run T time steps on fresh Poisson samples of img and sum the outputs.

        The result counts the output-layer spikes over the whole simulation
        (a [batch_size, 10] tensor per the original author's notes).
        """
        for t in range(T):
            if t == 0:
                counter = net(encoder(img).float(), split_sizes)
            else:
                counter += net(encoder(img).float(), split_sizes)
        return counter

    train_times = 0
    try:
        for _ in range(train_epoch):
            net.train()
            for img, label in train_data_loader:
                label = label.to(net.gpu_list[-1])
                optimizer.zero_grad()

                # Spike count / T = firing rate of each output neuron.
                out_spikes_counter_frequency = simulate(img) / T

                # Cross-entropy between the output firing rates and the true
                # class, which drives the rate of neuron i up for class-i
                # inputs and the other rates down.
                loss = F.cross_entropy(out_spikes_counter_frequency, label)
                loss.backward()
                optimizer.step()
                # SNN neurons have "memory", so reset the network state after
                # every optimisation step.
                net.reset_()

                # The predicted class is the output neuron with the highest rate.
                correct_rate = (out_spikes_counter_frequency.max(1)[1] == label
                                ).float().mean().item()
                writer.add_scalar('train_correct_rate', correct_rate, train_times)
                if train_times % 1024 == 0:
                    print(gpu_list, dataset_dir, batch_size, split_sizes, learning_rate,
                          T, tau, train_epoch, log_dir)
                    print(sys.argv, 'train_times', train_times,
                          'train_correct_rate', correct_rate)
                train_times += 1

            # After every pass over the training set, evaluate on the test set.
            net.eval()
            with torch.no_grad():
                test_sum = 0
                correct_sum = 0
                for img, label in test_data_loader:
                    label = label.to(net.gpu_list[-1])
                    out_spikes_counter = simulate(img)
                    correct_sum += (out_spikes_counter.max(1)[1] == label
                                    ).float().sum().item()
                    test_sum += label.numel()
                    net.reset_()
                writer.add_scalar('test_correct_rate', correct_sum / test_sum, train_times)
    finally:
        # Flush and close the event file even if training stops early.
        writer.close()
def main():
    """Run a spike-level (G2S / Gradient Trigger) transfer attack between two MNIST SNNs.

    Adversarial images are crafted against ``source_net`` from the gradients of
    its loss with respect to each of the ``T`` Poisson spike frames:

    - Frames with a non-negligible gradient go through the G2S converter.
    - Frames whose gradient is effectively zero use the Gradient Trigger.

    The per-frame perturbations are averaged into an image update ``ik``. The
    update is applied only while its L2 norm stays below ``perturbation``.
    Each adversarial image is then classified by the source model and by
    ``target_net`` (both fed fresh Poisson spikes at every time step).

    ``attack_type == 'UT'`` (untargeted): the first 250 test images are attacked
    by ascending the cross-entropy of the true label; success means a wrong
    prediction. Otherwise (targeted): each image is attacked towards each other
    label by descending the MSE between the output firing rates and the target's
    one-hot vector, until 270 (image, target) pairs have been processed; success
    means the target label is predicted.
    """
    # Run configuration (these used to be read interactively with input()).
    device = 'cuda:3'               # device to run on, e.g. 'cpu' or 'cuda:0'
    dataset_dir = '../../dataset/'  # where MNIST is stored / downloaded to
    class_num = 10
    T = 50                          # simulation length in time steps
    phase = 'BIM'

    torch.cuda.empty_cache()
    encoder = encoding.PoissonEncoder()

    if phase != 'BIM':
        return

    source_model_path = './models/mnist_spike_v1.pth'
    target_model_path = './models/mnist_spike_v2.pth'
    gamma = 0.05        # per-element flip probability of the Gradient Trigger
    iter_num = 50       # attack iterations per (image, goal)
    perturbation = 3.1  # an update ik is applied only if its L2 norm is below this
    attack_type = 'T'   # 'UT' = untargeted, anything else = targeted

    targeted = attack_type != 'UT'
    # Untargeted runs attack 250 images; targeted runs attack 270
    # (image, target) pairs, i.e. 30 images x 9 targets.
    max_attacks = 270 if targeted else 250
    # Targeted attacks move spikes against the gradient of the target loss;
    # untargeted attacks move them along the gradient of the true-label loss.
    step_sign = -1 if targeted else 1

    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=1,
        shuffle=False,
        drop_last=False)

    # map_location lets checkpoints saved on another device load directly here.
    source_net = Net().to(device)
    source_net.load_state_dict(torch.load(source_model_path, map_location=device))
    target_net = Net().to(device)
    target_net.load_state_dict(torch.load(target_model_path, map_location=device))
    target_net.eval()

    def attack_jobs():
        """Yield (image, true label, goal label) for every attack to run."""
        for X, y in test_data_loader:
            img, label = X.to(device), y.to(device)
            if targeted:
                for i in range(1, class_num):
                    yield img, label, (label + i) % class_num
            else:
                yield img, label, label

    def spike_perturbation(spike_train, grads):
        """Turn per-frame spike gradients into the averaged image update ik.

        Must be called under torch.no_grad() so that ik (and hence img) does
        not drag an autograd graph over all spike frames across iterations.
        """
        ik = torch.zeros_like(spike_train[0])
        for spike, grad in zip(spike_train, grads):
            if torch.max(torch.abs(grad)) > 1e-32:
                # G2S converter: move each spike one unit along the gradient
                # sign with probability given by the min-max normalised |grad|,
                # keeping the result inside [0, 1].
                grad_abs = torch.abs(grad)
                g_min, g_max = torch.min(grad_abs), torch.max(grad_abs)
                span = g_max - g_min
                if span > 0:
                    grad_norm = (grad_abs - g_min) / span
                else:
                    # All |grad| equal: 0/0 would give NaN and torch.bernoulli
                    # would raise; treat every element as maximal instead.
                    grad_norm = torch.ones_like(grad_abs)
                grad_mask = torch.bernoulli(grad_norm)
                g2s = step_sign * torch.sign(grad) * grad_mask
                ik += torch.clamp(g2s + spike, 0.0, 1.0) - spike
            else:
                # Gradient Trigger: with no usable gradient, flip each spike
                # with probability gamma.
                gt = torch.bernoulli(torch.ones_like(grad) * gamma)
                ik += (gt.bool() ^ spike.bool()).float() - spike
        return ik / T

    def craft(img, goal):
        """Return (clean copy of img, adversarial img) after iter_num attack steps."""
        img_ori = img.clone()
        source_net.train()
        for _ in range(iter_num):
            spike_train = []
            for t in range(T):
                spike = encoder(img).float()
                spike.requires_grad = True
                spike_train.append(spike)
                if t == 0:
                    out_spikes_counter = source_net(spike).unsqueeze(0)
                else:
                    out_spikes_counter += source_net(spike).unsqueeze(0)
            out_spikes_counter_frequency = out_spikes_counter / T
            if targeted:
                loss = F.mse_loss(out_spikes_counter_frequency,
                                  F.one_hot(goal, class_num).float())
            else:
                loss = F.cross_entropy(out_spikes_counter_frequency, goal)
            # Gradients w.r.t. the spike frames only; parameter gradients are
            # never needed, so nothing has to be zeroed by hand afterwards.
            grads = torch.autograd.grad(loss, spike_train)
            with torch.no_grad():
                ik = spike_perturbation(spike_train, grads)
                l2_norm = torch.norm(ik.view(ik.size()[0], -1), dim=1).item()
                if l2_norm < perturbation:
                    img = torch.clamp(img + ik, 0.0, 1.0)
            # SNN neurons keep state between forward passes; clear it before
            # the next T-step simulation.
            source_net.reset_()
        return img_ori, img

    def evaluate(img):
        """Return the (source, target) output spike counts for img over T steps."""
        source_net.eval()
        with torch.no_grad():
            for t in range(T):
                source_step = source_net(encoder(img).float()).unsqueeze(0)
                target_step = target_net(encoder(img).float()).unsqueeze(0)
                if t == 0:
                    source_counter, target_counter = source_step, target_step
                else:
                    source_counter += source_step
                    target_counter += target_step
        source_net.reset_()
        target_net.reset_()
        return source_counter, target_counter

    def succeeded(output, label, goal):
        """Return the number (0 or 1 at batch size 1) of top-1 predictions meeting the goal."""
        pred = output.max(1)[1]
        hit = (pred == goal) if targeted else (pred != label)
        return hit.float().sum().item()

    def report(name, flag):
        print('%s %s' % (name, 'Success' if flag > 0.5 else 'Failure'))

    mean_p = 0.0
    test_sum = 0
    source_success_sum = 0
    target_success_sum = 0

    for img, label, goal in attack_jobs():
        test_sum += 1
        if not targeted:
            print('Img %d' % test_sum)

        img_ori, img_adv = craft(img, goal)
        img_diff = img_adv - img_ori
        l2_norm = torch.norm(img_diff.view(img_diff.size()[0], -1), dim=1).item()
        print('Total Perturbation: %f' % l2_norm)
        mean_p += l2_norm

        source_output, target_output = evaluate(img_adv)
        source_attack_flag = succeeded(source_output, label, goal)
        target_attack_flag = succeeded(target_output, label, goal)
        source_success_sum += source_attack_flag
        target_success_sum += target_attack_flag
        report('Source Attack', source_attack_flag)
        report('Target Attack', target_attack_flag)

        # One flat loop, so this cut-off ends the whole run; with nested
        # image/target loops a break only leaves the inner loop.
        if test_sum >= max_attacks:
            break

    # Average over the attacks actually run (== max_attacks unless the test
    # set ran out first); max() guards against an empty loader.
    n = max(test_sum, 1)
    mean_p /= n
    print('Mean Perturbation: %.2f' % mean_p)
    print('source_success_sum: %d' % source_success_sum)
    print('target_success_sum: %d' % target_success_sum)
    print('test_sum: %d' % test_sum)
    print('source_success_rate: %.2f%%' % (100 * source_success_sum / n))
    print('target_success_rate: %.2f%%' % (100 * target_success_sum / n))