Ejemplo n.º 1
0
def main():
    """Run a BIM transfer attack between two spiking networks on CIFAR-10.

    Adversarial images are crafted against the source network
    (``source_model_path``) and then classified by both the source and the
    target network (``target_model_path``).  Every setting is hard-coded
    below.  ``attack_type == 'UT'`` runs the untargeted attack on the first
    250 test images; any other value runs the targeted attack, trying each
    of the ``class_num - 1`` wrong classes per image until 270 attacks have
    been made.  Progress and summary statistics are printed; nothing is
    returned.
    """
    # Run configuration (the interactive input() prompts are disabled).
    device = 'cuda:3'
    dataset_dir = '../../dataset/'
    class_num = 10
    T = 50  # simulation steps accumulated per forward pass
    phase = 'BIM'

    torch.cuda.empty_cache()

    encoder = encoding.PoissonEncoder()

    if phase != 'BIM':
        return

    source_model_path = './models/cifar10_spike_v1.pth'
    target_model_path = './models/cifar10_spike_v2.pth'

    iter_num = 25   # attack iterations per image
    eta = 0.03      # step size of each signed-gradient update
    attack_type = 'UT'

    clip_eps = 0.35  # half-width of the box around the clean image

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(
            root=dataset_dir,
            train=False,
            transform=transform_test,
            download=True),
        batch_size=1,
        shuffle=False,
        drop_last=False)

    # Normalised images of all-ones / all-zeros pixels, handed to
    # clip_by_tensor as extra bounds (presumably the valid pixel range in
    # normalised space -- confirm against clip_by_tensor).
    p_max = transform_test(np.ones((32, 32, 3))).to(device)
    p_min = transform_test(np.zeros((32, 32, 3))).to(device)

    source_net = Net().to(device)
    source_net.load_state_dict(torch.load(source_model_path))

    target_net = Net().to(device)
    target_net.load_state_dict(torch.load(target_model_path))

    target_net.eval()

    def signed_input_gradient(img, loss_label):
        """Return the sign of the loss gradient summed over all input spikes."""
        spike_train = []
        out_spikes_counter = None
        for _ in range(T):
            spike = encoder(img).float()
            spike.requires_grad = True
            spike_train.append(spike)
            out = source_net(spike).unsqueeze(0)
            if out_spikes_counter is None:
                out_spikes_counter = out
            else:
                out_spikes_counter += out

        # loss = F.mse_loss(out_spikes_counter / T, F.one_hot(loss_label, class_num).float())
        loss = F.cross_entropy(out_spikes_counter / T, loss_label)
        loss.backward()

        rate = torch.zeros_like(spike_train[0])
        for spike in spike_train:
            rate += spike.grad.data

        # Leave the source network clean for the next iteration.
        source_net.reset_()
        for p in source_net.parameters():
            if p.grad is not None:
                p.grad.data.zero_()

        return torch.sign(rate)

    def attack_and_evaluate(img, goal_label, targeted):
        """Attack one image; return (l_inf perturbation, source flag, target flag).

        Untargeted attacks ascend the loss w.r.t. ``goal_label`` (the true
        label) and succeed when the prediction differs from it; targeted
        attacks descend the loss and succeed when the prediction equals it.
        """
        img_ori = img.clone()
        step = -eta if targeted else eta

        source_net.train()
        for _ in range(iter_num):
            img_grad = signed_input_gradient(img, goal_label)
            # The clipped step must become the next iterate; assigning it to
            # a throw-away name would leave `img` unchanged forever.
            img = clip_by_tensor(img + step * img_grad,
                                 img_ori - clip_eps, img_ori + clip_eps,
                                 p_min, p_max)
        source_net.eval()

        with torch.no_grad():
            l_norm = torch.max(torch.abs(img - img_ori)).item()
            print('Perturbation: %f' % l_norm)

            source_counter = None
            target_counter = None
            for _ in range(T):
                source_out = source_net(encoder(img).float()).unsqueeze(0)
                target_out = target_net(encoder(img).float()).unsqueeze(0)
                if source_counter is None:
                    source_counter, target_counter = source_out, target_out
                else:
                    source_counter += source_out
                    target_counter += target_out

            source_pred = source_counter.max(1)[1]
            target_pred = target_counter.max(1)[1]
            if targeted:
                source_hit = source_pred == goal_label
                target_hit = target_pred == goal_label
            else:
                source_hit = source_pred != goal_label
                target_hit = target_pred != goal_label
            source_attack_flag = source_hit.float().sum().item()
            target_attack_flag = target_hit.float().sum().item()

            source_net.reset_()
            target_net.reset_()

            if source_attack_flag > 0.5:
                print('Source Attack Success')
            else:
                print('Source Attack Failure')

            if target_attack_flag > 0.5:
                print('Target Attack Success')
            else:
                print('Target Attack Failure')

        return l_norm, source_attack_flag, target_attack_flag

    mean_p = 0.0
    test_sum = 0
    source_success_sum = 0
    target_success_sum = 0

    if attack_type == 'UT':
        for X, y in test_data_loader:
            img, label = X.to(device), y.to(device)

            test_sum += 1
            print('Img %d' % test_sum)

            l_norm, source_flag, target_flag = attack_and_evaluate(
                img, label, targeted=False)
            mean_p += l_norm
            source_success_sum += source_flag
            target_success_sum += target_flag

            if test_sum >= 250:
                break
    else:
        for X, y in test_data_loader:
            img, label = X.to(device), y.to(device)

            # Try every class other than the true one as the target.
            for i in range(1, class_num):
                target_label = (label + i) % class_num
                test_sum += 1

                l_norm, source_flag, target_flag = attack_and_evaluate(
                    img, target_label, targeted=True)
                mean_p += l_norm
                source_success_sum += source_flag
                target_success_sum += target_flag

            if test_sum >= 270:
                break

    # Average over the attacks actually run rather than a fixed constant, so
    # the mean stays correct even if the loader ends before the limit.
    if test_sum:
        mean_p /= test_sum

    print('Mean Perturbation: %.3f' % mean_p)
    print('source_success_sum: %d' % source_success_sum)
    print('target_success_sum: %d' % target_success_sum)
    print('test_sum: %d' % test_sum)
    print('source_success_rate: %.2f%%' % (100 * source_success_sum / test_sum))
    print('target_success_rate: %.2f%%' % (100 * target_success_sum / test_sum))
def main():
    """Run a BIM adversarial attack against a spiking network on MNIST.

    Every setting is hard-coded below.  ``attack_type == 'UT'`` runs the
    untargeted attack on the first 250 test images; any other value runs
    the targeted attack, trying each of the ``class_num - 1`` wrong classes
    per image until 270 attacks have been made.  Per-attack L2 perturbation
    and success are printed, followed by summary statistics; nothing is
    returned.
    """
    # Run configuration (the interactive input() prompts are disabled).
    device = 'cuda:1'
    dataset_dir = '../../dataset/'
    class_num = 10
    T = 50  # simulation steps accumulated per forward pass
    phase = 'BIM'

    torch.cuda.empty_cache()

    encoder = encoding.PoissonEncoder()

    if phase != 'BIM':
        return

    model_path = './models/mnist_spike_v1.pth'
    iter_num = 100  # attack iterations per image
    eta = 0.02      # step size of each signed-gradient update
    attack_type = 'T'

    clip_flag = True  # True: clip to a box around the clean image; False: clamp to [0, 1]
    clip_eps = 0.4

    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=1,
        shuffle=False,
        drop_last=False)

    net = Net().to(device)
    net.load_state_dict(torch.load(model_path))

    def signed_input_gradient(img, loss_label):
        """Return the sign of the loss gradient summed over all input spikes."""
        spike_train = []
        out_spikes_counter = None
        for _ in range(T):
            spike = encoder(img).float()
            spike.requires_grad = True
            spike_train.append(spike)
            out = net(spike).unsqueeze(0)
            if out_spikes_counter is None:
                out_spikes_counter = out
            else:
                out_spikes_counter += out

        # loss = F.mse_loss(out_spikes_counter / T, F.one_hot(loss_label, class_num).float())
        loss = F.cross_entropy(out_spikes_counter / T, loss_label)
        loss.backward()

        rate = torch.zeros_like(spike_train[0])
        for spike in spike_train:
            rate += spike.grad.data

        # Leave the network clean for the next iteration.
        net.reset_()
        for p in net.parameters():
            if p.grad is not None:
                p.grad.data.zero_()

        return torch.sign(rate)

    def attack_and_evaluate(img, goal_label, targeted):
        """Attack one image; return (L2 perturbation, success flag).

        Untargeted attacks ascend the loss w.r.t. ``goal_label`` (the true
        label) and succeed when the prediction differs from it; targeted
        attacks descend the loss and succeed when the prediction equals it.
        """
        img_ori = img.clone()
        step = -eta if targeted else eta

        net.train()
        for _ in range(iter_num):
            candidate = img + step * signed_input_gradient(img, goal_label)
            if clip_flag:
                img = clip_by_tensor(candidate,
                                     img_ori - clip_eps,
                                     img_ori + clip_eps)
            else:
                img = torch.clamp(candidate, 0.0, 1.0)
        net.eval()

        with torch.no_grad():
            img_diff = img - img_ori
            l2_norm = torch.norm(img_diff.view(img_diff.size()[0], -1),
                                 dim=1).item()
            print('Total Perturbation: %f' % l2_norm)

            out_spikes_counter = net(encoder(img).float()).unsqueeze(0)
            for _ in range(T - 1):
                out_spikes_counter += net(encoder(img).float()).unsqueeze(0)

            prediction = out_spikes_counter.max(1)[1]
            if targeted:
                hit = prediction == goal_label
            else:
                hit = prediction != goal_label
            attack_flag = hit.float().sum().item()

            # Reset after the evaluation passes as well, as is done after
            # every attack iteration; otherwise whatever state reset_()
            # clears (presumably neuron membrane potentials -- confirm in
            # Net) would carry over into the next sample's first iteration.
            net.reset_()

            if attack_flag > 0.5:
                print('Attack Success')
            else:
                print('Attack Failure')

        return l2_norm, attack_flag

    mean_p = 0.0
    test_sum = 0
    success_sum = 0

    if attack_type == 'UT':
        for X, y in test_data_loader:
            img, label = X.to(device), y.to(device)

            test_sum += 1
            print('Img %d' % test_sum)

            l2_norm, attack_flag = attack_and_evaluate(
                img, label, targeted=False)
            mean_p += l2_norm
            success_sum += attack_flag

            if test_sum >= 250:
                break
    else:
        for X, y in test_data_loader:
            img, label = X.to(device), y.to(device)

            # Try every class other than the true one as the target.
            for i in range(1, class_num):
                target_label = (label + i) % class_num
                test_sum += 1

                l2_norm, attack_flag = attack_and_evaluate(
                    img, target_label, targeted=True)
                mean_p += l2_norm
                success_sum += attack_flag

            if test_sum >= 270:
                break

    # Average over the attacks actually run rather than a fixed constant, so
    # the mean stays correct even if the loader ends before the limit.
    if test_sum:
        mean_p /= test_sum

    print('Mean Perturbation: %.2f' % mean_p)
    print('success_sum: %d' % success_sum)
    print('test_sum: %d' % test_sum)
    print('success_rate: %.2f%%' % (100 * success_sum / test_sum))
Ejemplo n.º 3
0
def main():
    """Train a spiking network on CIFAR-10, or attack a trained one with BIM.

    Every setting is hard-coded below.  With ``phase == 'train'`` the
    network is trained with Adam on an MSE loss between the mean output
    over ``T`` simulation steps and the one-hot label; after each epoch it
    is evaluated on the test set and the best checkpoint so far is kept in
    ``model_dir``.  With ``phase == 'BIM'`` a saved model is attacked:
    ``attack_type == 'UT'`` runs the untargeted attack on the first 250
    test images, any other value runs the targeted attack over every wrong
    class until 270 attacks have been made.  Nothing is returned.
    """
    # Run configuration (the interactive input() prompts are disabled).
    device = 'cuda:3'
    dataset_dir = '../../dataset/'
    class_num = 10
    lr = 1e-4
    T = 8  # simulation steps accumulated per forward pass
    phase = 'train'

    torch.cuda.empty_cache()

    encoder = encoding.PoissonEncoder()

    if phase == 'train':
        model_dir = './models/'
        batch_size = 64
        train_epoch = 9999999
        log_dir = './logs/'

        writer = SummaryWriter(log_dir)

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        train_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(root=dataset_dir,
                                                 train=True,
                                                 transform=transform_train,
                                                 download=True),
            batch_size=batch_size,
            shuffle=True,
            drop_last=True)

        test_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(root=dataset_dir,
                                                 train=False,
                                                 transform=transform_test,
                                                 download=True),
            batch_size=batch_size,
            shuffle=True,
            drop_last=False)

        net = Net().to(device)

        optimizer = torch.optim.Adam(net.parameters(), lr=lr)

        # Make sure the checkpoint directory exists before the first save.
        os.makedirs(model_dir, exist_ok=True)

        train_times = 0
        best_epoch = 0
        max_correct_sum = 0

        for epoch in range(1, train_epoch + 1):
            net.train()
            for X, y in train_data_loader:
                img, label = X.to(device), y.to(device)

                optimizer.zero_grad()

                out_spikes_counter = net(encoder(img).float())
                for _ in range(T - 1):
                    out_spikes_counter += net(encoder(img).float())

                out_spikes_counter_frequency = out_spikes_counter / T

                loss = F.mse_loss(out_spikes_counter_frequency,
                                  F.one_hot(label, class_num).float())
                # loss = F.cross_entropy(out_spikes_counter_frequency, label)

                loss.backward()
                optimizer.step()

                net.reset_()

                correct_rate = (out_spikes_counter_frequency.max(1)[1] == label
                                ).float().mean().item()
                writer.add_scalar('train_correct_rate', correct_rate,
                                  train_times)
                train_times += 1

            net.eval()

            with torch.no_grad():
                test_sum = 0
                correct_sum = 0
                for X, y in test_data_loader:
                    img, label = X.to(device), y.to(device)

                    # Accumulate over all T steps; overwriting the counter
                    # each step would score only the final time step.
                    out_spikes_counter = net(encoder(img).float())
                    for _ in range(T - 1):
                        out_spikes_counter += net(encoder(img).float())

                    correct_sum += (out_spikes_counter.max(1)[1] == label
                                    ).float().sum().item()
                    test_sum += label.numel()
                    net.reset_()

                writer.add_scalar('test_correct_rate', correct_sum / test_sum,
                                  train_times)

                print('epoch', epoch, 'test_correct_rate',
                      correct_sum / test_sum)

                if correct_sum > max_correct_sum:
                    max_correct_sum = correct_sum
                    torch.save(net.state_dict(),
                               model_dir + 'spike_best_%d.pth' % (epoch))
                    if best_epoch > 0:
                        # Drop the previous best checkpoint without going
                        # through a shell.
                        stale_path = model_dir + 'spike_best_%d.pth' % best_epoch
                        if os.path.exists(stale_path):
                            os.remove(stale_path)
                    best_epoch = epoch

    elif phase == 'BIM':
        model_path = './models/cifar10_spike_v1.pth'
        iter_num = 50   # attack iterations per image
        eta = 0.03      # step size of each signed-gradient update
        attack_type = 'UT'

        clip_eps = 0.6  # half-width of the box around the clean image

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        test_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(root=dataset_dir,
                                                 train=False,
                                                 transform=transform_test,
                                                 download=True),
            batch_size=1,
            shuffle=False,
            drop_last=False)

        # Normalised images of all-ones / all-zeros pixels, handed to
        # clip_by_tensor as extra bounds (presumably the valid pixel range
        # in normalised space -- confirm against clip_by_tensor).
        p_max = transform_test(np.ones((32, 32, 3))).to(device)
        p_min = transform_test(np.zeros((32, 32, 3))).to(device)

        net = Net().to(device)
        net.load_state_dict(torch.load(model_path))

        def signed_input_gradient(img, loss_label):
            """Return the sign of the loss gradient summed over all input spikes."""
            spike_train = []
            out_spikes_counter = None
            for _ in range(T):
                spike = encoder(img).float()
                spike.requires_grad = True
                spike_train.append(spike)
                out = net(spike).unsqueeze(0)
                if out_spikes_counter is None:
                    out_spikes_counter = out
                else:
                    out_spikes_counter += out

            # loss = F.mse_loss(out_spikes_counter / T, F.one_hot(loss_label, class_num).float())
            loss = F.cross_entropy(out_spikes_counter / T, loss_label)
            loss.backward()

            rate = torch.zeros_like(spike_train[0])
            for spike in spike_train:
                rate += spike.grad.data

            # Leave the network clean for the next iteration.
            net.reset_()
            for p in net.parameters():
                if p.grad is not None:
                    p.grad.data.zero_()

            return torch.sign(rate)

        def attack_and_evaluate(img, goal_label, targeted):
            """Attack one image; return (l_inf perturbation, success flag).

            Untargeted attacks ascend the loss w.r.t. ``goal_label`` (the
            true label) and succeed when the prediction differs from it;
            targeted attacks descend the loss and succeed when the
            prediction equals it.
            """
            img_ori = img.clone()
            step = -eta if targeted else eta

            net.train()
            for _ in range(iter_num):
                img_grad = signed_input_gradient(img, goal_label)
                # The clipped step must become the next iterate; assigning
                # it to a throw-away name would leave `img` unchanged.
                img = clip_by_tensor(img + step * img_grad,
                                     img_ori - clip_eps,
                                     img_ori + clip_eps, p_min, p_max)
            net.eval()

            with torch.no_grad():
                l_norm = torch.max(torch.abs(img - img_ori)).item()
                print('Perturbation: %f' % l_norm)

                out_spikes_counter = net(encoder(img).float()).unsqueeze(0)
                for _ in range(T - 1):
                    out_spikes_counter += net(
                        encoder(img).float()).unsqueeze(0)

                prediction = out_spikes_counter.max(1)[1]
                if targeted:
                    hit = prediction == goal_label
                else:
                    hit = prediction != goal_label
                attack_flag = hit.float().sum().item()

                # Reset after the evaluation passes too, as the training
                # loop does after every batch, so no state carries over
                # into the next sample.
                net.reset_()

                if attack_flag > 0.5:
                    print('Attack Success')
                else:
                    print('Attack Failure')

            return l_norm, attack_flag

        mean_p = 0.0
        test_sum = 0
        success_sum = 0

        if attack_type == 'UT':
            for X, y in test_data_loader:
                img, label = X.to(device), y.to(device)

                test_sum += 1
                print('Img %d' % test_sum)

                l_norm, attack_flag = attack_and_evaluate(
                    img, label, targeted=False)
                mean_p += l_norm
                success_sum += attack_flag

                if test_sum >= 250:
                    break
        else:
            for X, y in test_data_loader:
                img, label = X.to(device), y.to(device)

                # Try every class other than the true one as the target.
                for i in range(1, class_num):
                    target_label = (label + i) % class_num
                    test_sum += 1

                    l_norm, attack_flag = attack_and_evaluate(
                        img, target_label, targeted=True)
                    mean_p += l_norm
                    success_sum += attack_flag

                if test_sum >= 270:
                    break

        # Average over the attacks actually run rather than a fixed
        # constant, so the mean stays correct even if the loader ends
        # before the limit.
        if test_sum:
            mean_p /= test_sum

        print('Mean Perturbation: %.3f' % mean_p)
        print('success_sum: %d' % success_sum)
        print('test_sum: %d' % test_sum)
        print('success_rate: %.2f%%' % (100 * success_sum / test_sum))
def main():
    """Train an MNIST spiking classifier, or attack a trained one in the spike domain.

    All settings are hard-coded below; the interactive prompts that used to
    collect them are disabled.  ``phase`` selects the mode:

    * ``'train'`` -- trains ``Net`` on Poisson-encoded MNIST with an MSE loss
      on the output firing rate, logs accuracy through ``SummaryWriter`` and
      keeps only the best-scoring checkpoint on disk.
    * ``'BIM'`` -- loads a checkpoint and iteratively perturbs each test
      image.  Gradients are taken with respect to the encoded spike trains,
      turned into a spike-compatible update ("G2S converter", or the random
      "gradient trigger" when every gradient is ~0), averaged over the ``T``
      time steps and added to the image.  ``attack_type == 'UT'`` runs one
      untargeted attack per image; any other value runs one targeted attack
      per wrong class.

    Prints per-sample results and, at the end, the mean L2 perturbation and
    the success rate.  Returns ``None``.
    """
    device = 'cuda:1'
    dataset_dir = '../../dataset/'
    class_num = 10
    lr = 1e-4
    T = 50  # simulation length: forward passes accumulated per sample
    phase = 'BIM'

    torch.cuda.empty_cache()

    encoder = encoding.PoissonEncoder()

    def mnist_loader(train, batch_size, shuffle, drop_last):
        """Build a DataLoader over the MNIST train or test split."""
        return torch.utils.data.DataLoader(
            dataset=torchvision.datasets.MNIST(
                root=dataset_dir,
                train=train,
                transform=torchvision.transforms.ToTensor(),
                download=True),
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last)

    def accumulate(total, out):
        """Add one time step's network output to the running spike count."""
        if total is None:
            return out
        total += out
        return total

    if phase == 'train':
        model_dir = './models/'
        batch_size = 64
        train_epoch = 9999999
        log_dir = './logs/'

        writer = SummaryWriter(log_dir)

        train_data_loader = mnist_loader(True, batch_size, True, True)
        test_data_loader = mnist_loader(False, batch_size, True, False)

        net = Net().to(device)

        optimizer = torch.optim.Adam(net.parameters(), lr=lr)

        train_times = 0
        best_epoch = 0
        max_correct_sum = 0

        for epoch in range(1, train_epoch + 1):
            net.train()
            for X, y in train_data_loader:
                img, label = X.to(device), y.to(device)

                optimizer.zero_grad()

                out_spikes_counter = None
                for _ in range(T):
                    out_spikes_counter = accumulate(
                        out_spikes_counter, net(encoder(img).float()))

                out_spikes_counter_frequency = out_spikes_counter / T

                loss = F.mse_loss(out_spikes_counter_frequency,
                                  F.one_hot(label, class_num).float())

                loss.backward()
                optimizer.step()

                # Clear the network state before the next batch.
                net.reset_()

                correct_rate = (out_spikes_counter_frequency.max(1)[1]
                                == label).float().mean().item()
                writer.add_scalar('train_correct_rate', correct_rate,
                                  train_times)
                train_times += 1

            net.eval()

            with torch.no_grad():
                test_sum = 0
                correct_sum = 0
                for X, y in test_data_loader:
                    img, label = X.to(device), y.to(device)

                    # Accumulate over all T steps.  The previous code
                    # re-assigned the counter on every step, so accuracy was
                    # computed from the last time step only.
                    out_spikes_counter = None
                    for _ in range(T):
                        out_spikes_counter = accumulate(
                            out_spikes_counter, net(encoder(img).float()))

                    correct_sum += (out_spikes_counter.max(1)[1]
                                    == label).float().sum().item()
                    test_sum += label.numel()
                    net.reset_()

                writer.add_scalar('test_correct_rate', correct_sum / test_sum,
                                  train_times)

                print('epoch', epoch, 'test_correct_rate',
                      correct_sum / test_sum)

                if correct_sum > max_correct_sum:
                    max_correct_sum = correct_sum
                    torch.save(net.state_dict(),
                               model_dir + 'spike_best_%d.pth' % (epoch))
                    if best_epoch > 0:
                        # Drop the superseded checkpoint without shelling out
                        # to `rm`; a failed delete is reported, not fatal.
                        stale = '%sspike_best_%d.pth' % (model_dir, best_epoch)
                        try:
                            os.remove(stale)
                        except OSError as err:
                            print('could not remove %s: %s' % (stale, err))
                    best_epoch = epoch

    elif phase == 'BIM':
        model_path = './models/mnist_spike_v1.pth'
        gamma = 0.05  # per-element flip probability of the gradient trigger
        iter_num = 50
        perturbation = 3.1  # L2 bound a single update must stay below
        attack_type = 'T'

        test_data_loader = mnist_loader(False, 1, False, False)

        net = Net().to(device)
        # map_location keeps the load working when the checkpoint was saved
        # from a different device than the one used here.
        net.load_state_dict(torch.load(model_path, map_location=device))

        mean_p = 0.0
        test_sum = 0
        success_sum = 0

        # Anything other than 'UT' is treated as a targeted attack.
        targeted = attack_type != 'UT'
        # Targeted: one attack per wrong class.  Untargeted: one per image.
        offsets = range(1, class_num) if targeted else (0,)
        sample_budget = 270 if targeted else 250
        # Untargeted ascends the true-label loss; targeted descends the
        # target-label loss.
        grad_direction = -1.0 if targeted else 1.0

        def attack_loss(frequency, goal):
            """Loss whose spike gradients drive the attack."""
            if targeted:
                return F.mse_loss(frequency,
                                  F.one_hot(goal, class_num).float())
            return F.cross_entropy(frequency, goal)

        def spike_update(spike_train):
            """Turn per-step spike gradients into one image-sized update."""
            # no_grad: the update is plain data; tracking it would chain an
            # autograd graph onto the image across iterations.
            with torch.no_grad():
                ik = torch.zeros_like(spike_train[-1]).to(device)
                for spike in spike_train:
                    grad = spike.grad.data
                    grad_abs = torch.abs(grad)
                    if torch.max(grad_abs) > 1e-32:
                        # G2S converter: sample a mask with probability given
                        # by the min-max normalised gradient magnitude, apply
                        # the signed flips, keep the result a valid spike.
                        # NOTE(review): if every magnitude is equal the
                        # normalisation divides by zero -- confirm whether
                        # that can occur in practice.
                        low = torch.min(grad_abs)
                        high = torch.max(grad_abs)
                        grad_norm = (grad_abs - low) / (high - low)
                        grad_mask = torch.bernoulli(grad_norm)
                        G2S = grad_direction * torch.sign(grad) * grad_mask
                        ik += torch.clamp(G2S + spike, 0.0, 1.0) - spike
                    else:
                        # Gradient trigger: gradients vanished, so flip
                        # spikes at random with probability gamma.
                        GT = torch.bernoulli(torch.ones_like(grad) * gamma)
                        ik += (GT.bool() ^ spike.bool()).float() - spike
                ik /= T
            return ik

        def craft(img, goal):
            """Run iter_num attack iterations and return the perturbed image."""
            net.train()
            for _ in range(iter_num):
                spike_train = []
                out_spikes_counter = None
                for _ in range(T):
                    spike = encoder(img).float()
                    spike.requires_grad = True
                    spike_train.append(spike)
                    out_spikes_counter = accumulate(
                        out_spikes_counter, net(spike).unsqueeze(0))

                loss = attack_loss(out_spikes_counter / T, goal)
                loss.backward()

                ik = spike_update(spike_train)
                l2_norm = torch.norm(ik.view(ik.size()[0], -1), dim=1).item()

                # Updates at or above the bound are discarded.
                if l2_norm < perturbation:
                    img = torch.clamp(img + ik, 0.0, 1.0)

                net.reset_()
                for p in net.parameters():
                    if p.grad is not None:
                        p.grad.data.zero_()
            return img

        for X, y in test_data_loader:
            for i in offsets:
                img, label = X.to(device), y.to(device)
                img_ori = img.clone()
                goal = (label + i) % class_num if targeted else label

                test_sum += 1
                if not targeted:
                    print('Img %d' % test_sum)

                img = craft(img, goal)

                net.eval()

                with torch.no_grad():
                    img_diff = img - img_ori

                    l2_norm = torch.norm(
                        img_diff.view(img_diff.size()[0], -1), dim=1).item()
                    print('Total Perturbation: %f' % l2_norm)

                    mean_p += l2_norm

                    out_spikes_counter = None
                    for _ in range(T):
                        out_spikes_counter = accumulate(
                            out_spikes_counter,
                            net(encoder(img).float()).unsqueeze(0))

                    predicted = out_spikes_counter.max(1)[1]
                    if targeted:
                        hit = predicted == goal
                    else:
                        hit = predicted != goal
                    attack_flag = hit.float().sum().item()
                    success_sum += attack_flag

                    if attack_flag > 0.5:
                        print('Attack Success')
                    else:
                        print('Attack Failure')

                # Reset after the evaluation pass too, so its state does not
                # carry into the next sample's first attack iteration.
                net.reset_()

            if test_sum >= sample_budget:
                break

        # Average over the samples actually attacked, not a fixed count.
        mean_p /= test_sum

        print('Mean Perturbation: %.2f' % mean_p)
        print('success_sum: %d' % success_sum)
        print('test_sum: %d' % test_sum)
        print('success_rate: %.2f%%' % (100 * success_sum / test_sum))
def main():
    """Measure how image-domain BIM examples transfer between MNIST models.

    All settings are hard-coded below; the interactive prompts that used to
    collect them are disabled.  Adversarial images are crafted against
    ``source_net`` (an ``SNN_Net`` fed the raw image) by iterating signed
    gradient steps of size ``eta``.  Each step is either clipped to within
    ``clip_eps`` of the original image via ``clip_by_tensor``
    (``clip_flag=True``) or clamped to ``[0, 1]``.  Every crafted image is
    then scored on the source network and on two target networks: an
    ``ANN_Net``, and an ``SNN_Net`` fed Poisson-encoded spikes.

    ``attack_type == 'UT'`` runs one untargeted attack per image; any other
    value runs one targeted attack per wrong class.  Prints per-sample
    outcomes and the final mean L2 perturbation and success rates.
    Returns ``None``.
    """
    device = 'cuda:3'
    dataset_dir = '../../dataset/'
    class_num = 10
    T = 50  # simulation length: forward passes accumulated per sample
    phase = 'BIM'

    torch.cuda.empty_cache()

    encoder = encoding.PoissonEncoder()

    if phase == 'BIM':
        source_model_path = './models/mnist_img_v1.pth'
        target_model_path1 = './models/mnist_ann_v1.pth'
        target_model_path2 = './models/mnist_spike_v1.pth'

        iter_num = 25
        eta = 0.02
        attack_type = 'T'

        clip_flag = True
        clip_eps = 0.3

        test_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.MNIST(
                root=dataset_dir,
                train=False,
                transform=torchvision.transforms.ToTensor(),
                download=True),
            batch_size=1,
            shuffle=False,
            drop_last=False)

        # map_location keeps the loads working when a checkpoint was saved
        # from a different device than the one used here.
        source_net = SNN_Net().to(device)
        source_net.load_state_dict(
            torch.load(source_model_path, map_location=device))

        target_net1 = ANN_Net().to(device)
        target_net1.load_state_dict(
            torch.load(target_model_path1, map_location=device))

        target_net2 = SNN_Net().to(device)
        target_net2.load_state_dict(
            torch.load(target_model_path2, map_location=device))

        target_net1.eval()
        target_net2.eval()

        mean_p = 0.0
        test_sum = 0
        source_success_sum = 0
        target_success_sum1 = 0
        target_success_sum2 = 0

        # Anything other than 'UT' is treated as a targeted attack.
        targeted = attack_type != 'UT'
        # Targeted: one attack per wrong class.  Untargeted: one per image.
        offsets = range(1, class_num) if targeted else (0,)
        sample_budget = 270 if targeted else 250
        # Untargeted ascends the true-label loss; targeted descends the
        # target-label loss.
        signed_eta = -eta if targeted else eta

        def craft(img, img_ori, goal):
            """Run iter_num BIM steps on source_net; return the final image."""
            source_net.train()
            for _ in range(iter_num):
                out_spikes_counter = None
                for _ in range(T):
                    out = source_net(img).unsqueeze(0)
                    if out_spikes_counter is None:
                        out_spikes_counter = out
                    else:
                        out_spikes_counter += out

                out_spikes_counter_frequency = out_spikes_counter / T

                loss = F.cross_entropy(out_spikes_counter_frequency, goal)
                loss.backward()

                img_grad = torch.sign(img.grad.data)
                stepped = img + signed_eta * img_grad

                if clip_flag:
                    img_adv = clip_by_tensor(stepped,
                                             img_ori - clip_eps,
                                             img_ori + clip_eps)
                else:
                    img_adv = torch.clamp(stepped, 0.0, 1.0)

                # Start the next iteration from a fresh leaf tensor.
                img = img_adv.detach().requires_grad_(True)

                source_net.reset_()
            return img

        def success(output, goal):
            """Return 1.0 when the prediction counts as a successful attack."""
            predicted = output.max(1)[1]
            hit = predicted == goal if targeted else predicted != goal
            return hit.float().sum().item()

        for X, y in test_data_loader:
            for i in offsets:
                img, label = X.to(device), y.to(device)
                img_ori = img.clone()
                img.requires_grad = True

                goal = (label + i) % class_num if targeted else label

                test_sum += 1
                if not targeted:
                    print('Img %d' % test_sum)

                img = craft(img, img_ori, goal)

                source_net.eval()

                with torch.no_grad():
                    img_diff = img - img_ori

                    l2_norm = torch.norm(
                        img_diff.view(img_diff.size()[0], -1), dim=1).item()
                    print('Perturbation: %f' % l2_norm)

                    mean_p += l2_norm

                    target_output1 = target_net1(img).unsqueeze(0)

                    for t in range(T):
                        source_step = source_net(img).unsqueeze(0)
                        target_step = target_net2(
                            encoder(img).float()).unsqueeze(0)
                        if t == 0:
                            source_out_spikes_counter = source_step
                            target_out_spikes_counter2 = target_step
                        else:
                            source_out_spikes_counter += source_step
                            target_out_spikes_counter2 += target_step

                    source_output = source_out_spikes_counter / T
                    target_output2 = target_out_spikes_counter2 / T

                    source_attack_flag = success(source_output, goal)
                    source_success_sum += source_attack_flag

                    target_attack_flag1 = success(target_output1, goal)
                    target_success_sum1 += target_attack_flag1

                    target_attack_flag2 = success(target_output2, goal)
                    target_success_sum2 += target_attack_flag2

                    source_net.reset_()
                    target_net2.reset_()

                    if source_attack_flag > 0.5:
                        print('Source Attack Success')
                    else:
                        print('Source Attack Failure')

                    if target_attack_flag1 > 0.5:
                        print('Target Attack 1 Success')
                    else:
                        print('Target Attack 1 Failure')

                    if target_attack_flag2 > 0.5:
                        print('Target Attack 2 Success')
                    else:
                        print('Target Attack 2 Failure')

            if test_sum >= sample_budget:
                break

        # Average over the samples actually attacked, not a fixed count.
        mean_p /= test_sum

        print('Mean Perturbation: %.2f' % mean_p)
        print('source_success_sum: %d' % source_success_sum)
        print('target_success_1_sum: %d' % target_success_sum1)
        print('target_success_2_sum: %d' % target_success_sum2)
        print('test_sum: %d' % test_sum)
        print('source_success_rate: %.2f%%' %
              (100 * source_success_sum / test_sum))
        print('target_success_1_rate: %.2f%%' %
              (100 * target_success_sum1 / test_sum))
        print('target_success_2_rate: %.2f%%' %
              (100 * target_success_sum2 / test_sum))
Ejemplo n.º 6
0
def main():
    """Measure how image-domain BIM examples transfer between CIFAR-10 models.

    All settings are hard-coded below; the interactive prompts that used to
    collect them are disabled.  Adversarial images are crafted against
    ``source_net`` (an ``ANN_Net``) by iterating signed gradient steps of
    size ``eta``.  After every step ``clip_by_tensor`` is given the bounds
    ``img_ori +/- clip_eps`` together with ``p_min`` / ``p_max``, the
    normalised values of an all-zero and an all-one image.  Every crafted
    image is then scored on the source network and on two ``SNN_Net``
    targets: one fed the raw image, one fed Poisson-encoded spikes.

    ``attack_type == 'UT'`` runs one untargeted attack per image; any other
    value runs one targeted attack per wrong class.  Prints per-sample
    outcomes and the final mean L-infinity perturbation and success rates.
    Returns ``None``.
    """
    device = 'cuda:3'
    dataset_dir = '../../dataset/'
    class_num = 10
    T = 50  # simulation length: forward passes accumulated per sample
    phase = 'BIM'

    torch.cuda.empty_cache()

    encoder = encoding.PoissonEncoder()

    if phase == 'BIM':
        source_model_path = './models/cifar10_ann_v1.pth'
        target_model_path1 = './models/cifar10_img_v1.pth'
        target_model_path2 = './models/cifar10_spike_v1.pth'

        iter_num = 25
        eta = 0.003
        attack_type = 'UT'

        clip_eps = 0.06

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        test_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(root=dataset_dir,
                                                 train=False,
                                                 transform=transform_test,
                                                 download=True),
            batch_size=1,
            shuffle=False,
            drop_last=False)

        # Per-pixel bounds of a valid image after normalisation.
        # NOTE(review): np.ones/np.zeros default to float64 -- confirm that
        # clip_by_tensor copes with bounds of that dtype.
        p_max = transform_test(np.ones((32, 32, 3))).to(device)
        p_min = transform_test(np.zeros((32, 32, 3))).to(device)

        # map_location keeps the loads working when a checkpoint was saved
        # from a different device than the one used here.
        source_net = ANN_Net().to(device)
        source_net.load_state_dict(
            torch.load(source_model_path, map_location=device))

        target_net1 = SNN_Net().to(device)
        target_net1.load_state_dict(
            torch.load(target_model_path1, map_location=device))

        target_net2 = SNN_Net().to(device)
        target_net2.load_state_dict(
            torch.load(target_model_path2, map_location=device))

        target_net1.eval()
        target_net2.eval()

        mean_p = 0.0
        test_sum = 0
        source_success_sum = 0
        target_success_sum1 = 0
        target_success_sum2 = 0

        # Anything other than 'UT' is treated as a targeted attack.
        targeted = attack_type != 'UT'
        # Targeted: one attack per wrong class.  Untargeted: one per image.
        offsets = range(1, class_num) if targeted else (0,)
        sample_budget = 270 if targeted else 250
        # Untargeted must ASCEND the true-label loss (+eta); targeted
        # descends the target-label loss (-eta).  The untargeted branch
        # previously also stepped with -eta, which pushed the image toward
        # its correct class instead of away from it.
        signed_eta = -eta if targeted else eta

        def craft(img, img_ori, goal):
            """Run iter_num BIM steps on source_net; return the final image."""
            # NOTE(review): train mode is kept from the original code --
            # confirm it is intended for ANN_Net while crafting.
            source_net.train()
            for _ in range(iter_num):
                output = source_net(img).unsqueeze(0)

                loss = F.cross_entropy(output, goal)
                loss.backward()

                img_grad = torch.sign(img.grad.data)

                img_adv = clip_by_tensor(img + signed_eta * img_grad,
                                         img_ori - clip_eps,
                                         img_ori + clip_eps, p_min, p_max)

                # Start the next iteration from a fresh leaf tensor.
                img = img_adv.detach().requires_grad_(True)
            return img

        def success(output, goal):
            """Return 1.0 when the prediction counts as a successful attack."""
            predicted = output.max(1)[1]
            hit = predicted == goal if targeted else predicted != goal
            return hit.float().sum().item()

        for X, y in test_data_loader:
            for i in offsets:
                img, label = X.to(device), y.to(device)
                img_ori = img.clone()
                img.requires_grad = True

                goal = (label + i) % class_num if targeted else label

                test_sum += 1
                if not targeted:
                    print('Img %d' % test_sum)

                img = craft(img, img_ori, goal)

                source_net.eval()

                with torch.no_grad():
                    img_diff = img - img_ori

                    l_norm = torch.max(torch.abs(img_diff)).item()
                    print('Perturbation: %f' % l_norm)

                    mean_p += l_norm

                    source_output = source_net(img).unsqueeze(0)

                    for t in range(T):
                        step1 = target_net1(img).unsqueeze(0)
                        step2 = target_net2(
                            encoder(img).float()).unsqueeze(0)
                        if t == 0:
                            target_out_spikes_counter1 = step1
                            target_out_spikes_counter2 = step2
                        else:
                            target_out_spikes_counter1 += step1
                            target_out_spikes_counter2 += step2

                    target_output1 = target_out_spikes_counter1 / T
                    target_output2 = target_out_spikes_counter2 / T

                    source_attack_flag = success(source_output, goal)
                    source_success_sum += source_attack_flag

                    target_attack_flag1 = success(target_output1, goal)
                    target_success_sum1 += target_attack_flag1

                    target_attack_flag2 = success(target_output2, goal)
                    target_success_sum2 += target_attack_flag2

                    target_net1.reset_()
                    target_net2.reset_()

                    if source_attack_flag > 0.5:
                        print('Source Attack Success')
                    else:
                        print('Source Attack Failure')

                    if target_attack_flag1 > 0.5:
                        print('Target Attack 1 Success')
                    else:
                        print('Target Attack 1 Failure')

                    if target_attack_flag2 > 0.5:
                        print('Target Attack 2 Success')
                    else:
                        print('Target Attack 2 Failure')

            if test_sum >= sample_budget:
                break

        # Average over the samples actually attacked, not a fixed count.
        mean_p /= test_sum

        print('Mean Perturbation: %.3f' % mean_p)
        print('source_success_sum: %d' % source_success_sum)
        print('target_success_1_sum: %d' % target_success_sum1)
        print('target_success_2_sum: %d' % target_success_sum2)
        print('test_sum: %d' % test_sum)
        print('source_success_rate: %.2f%%' %
              (100 * source_success_sum / test_sum))
        print('target_success_1_rate: %.2f%%' %
              (100 * target_success_sum1 / test_sum))
        print('target_success_2_rate: %.2f%%' %
              (100 * target_success_sum2 / test_sum))
Ejemplo n.º 7
0
def main():
    """Train and evaluate a spiking network on CIFAR10 using prompted settings.

    Every hyper-parameter (GPU list, dataset directory, batch size, split
    size, learning rate, simulation length ``T``, LIF time constant ``tau``,
    number of epochs and TensorBoard log directory) is read from standard
    input.  After each pass over the training set the network is evaluated on
    the test split.  The per-batch training accuracy and the per-epoch test
    accuracy are logged through ``SummaryWriter``.

    The writer is closed in a ``finally`` clause so that buffered events are
    flushed to disk even when training stops because of an exception.

    ``Net``, ``encoding`` and ``SummaryWriter`` are expected to be provided by
    the surrounding module.
    """
    gpu_list = input('输入使用的5个gpu,例如“0,1,2,0,3”  ').split(',')
    dataset_dir = input('输入保存CIFAR10数据集的位置,例如“./”  ')
    batch_size = int(input('输入batch_size,例如“64”  '))
    split_sizes = int(input('输入split_sizes,例如“16”  '))
    learning_rate = float(input('输入学习率,例如“1e-3”  '))
    T = int(input('输入仿真时长,例如“50”  '))
    tau = float(input('输入LIF神经元的时间常数tau,例如“100.0”  '))
    train_epoch = int(input('输入训练轮数,即遍历训练集的次数,例如“100”  '))
    log_dir = input('输入保存tensorboard日志文件的位置,例如“./”  ')

    writer = SummaryWriter(log_dir)

    try:
        # Initialise the data loaders.
        train_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(
                root=dataset_dir,
                train=True,
                transform=torchvision.transforms.ToTensor(),
                download=True),
            batch_size=batch_size,
            shuffle=True,
            drop_last=True)
        test_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.CIFAR10(
                root=dataset_dir,
                train=False,
                transform=torchvision.transforms.ToTensor(),
                download=True),
            batch_size=batch_size,
            shuffle=True,
            drop_last=False)

        # Initialise the network.
        net = Net(gpu_list=gpu_list, tau=tau)
        # Optimise with Adam.
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
        # Encode the input images with a Poisson encoder.
        encoder = encoding.PoissonEncoder()
        train_times = 0
        for _ in range(train_epoch):
            net.train()
            for img, label in train_data_loader:

                # Labels are moved to the last entry of the network's GPU
                # list so they can be compared with the network output.
                label = label.to(net.gpu_list[-1])
                optimizer.zero_grad()

                # Simulate T steps. out_spikes_counter is a tensor of
                # shape=[batch_size, 10] accumulating how often each of the
                # 10 output neurons fired during the whole simulation.
                for t in range(T):
                    if t == 0:
                        out_spikes_counter = net(
                            encoder(img).float(), split_sizes)
                    else:
                        out_spikes_counter += net(
                            encoder(img).float(), split_sizes)

                # Dividing by T gives the firing rate of the 10 output
                # neurons over the simulation.
                out_spikes_counter_frequency = out_spikes_counter / T

                # The loss is the cross entropy between the output firing
                # rates and the true class. It pushes the firing rate of
                # neuron i towards 1 for inputs of class i and the firing
                # rate of the other neurons towards 0.
                loss = F.cross_entropy(out_spikes_counter_frequency, label)
                loss.backward()
                optimizer.step()
                # After each optimisation step the network state must be
                # reset, because SNN neurons have "memory".
                net.reset_()

                # Accuracy: the index of the output neuron with the highest
                # firing rate is taken as the predicted class.
                correct_rate = (out_spikes_counter_frequency.max(1)[1] == label
                                ).float().mean().item()
                writer.add_scalar('train_correct_rate', correct_rate,
                                  train_times)
                if train_times % 1024 == 0:
                    print(gpu_list, dataset_dir, batch_size, split_sizes,
                          learning_rate, T, tau, train_epoch, log_dir)
                    print(sys.argv, 'train_times', train_times,
                          'train_correct_rate', correct_rate)
                train_times += 1
            net.eval()
            with torch.no_grad():
                # Evaluate on the test set once per pass over the
                # training set.
                test_sum = 0
                correct_sum = 0
                for img, label in test_data_loader:
                    label = label.to(net.gpu_list[-1])

                    for t in range(T):
                        if t == 0:
                            out_spikes_counter = net(
                                encoder(img).float(), split_sizes)
                        else:
                            out_spikes_counter += net(
                                encoder(img).float(), split_sizes)

                    correct_sum += (out_spikes_counter.max(1)[1] == label
                                    ).float().sum().item()
                    test_sum += label.numel()
                    net.reset_()

                writer.add_scalar('test_correct_rate', correct_sum / test_sum,
                                  train_times)
    finally:
        # Flush pending events and release the log file even on failure.
        writer.close()
Ejemplo n.º 8
0
def _spike_perturbation(spike_train, gamma, T, device, targeted):
    """Convert the gradients of the input spikes into an image perturbation.

    Args:
        spike_train: list of input spike tensors whose ``.grad`` has been
            populated by a backward pass.
        gamma: Bernoulli probability used for the random flips applied when a
            spike tensor received a (numerically) zero gradient.
        T: number of simulation steps; the summed perturbation is divided
            by it.
        device: device on which the accumulator is placed.
        targeted: when true the gradient sign is negated (loss descent),
            otherwise the raw gradient sign is used (loss ascent).

    Returns:
        The per-step perturbations summed over ``spike_train`` and divided
        by ``T``.
    """
    ik = torch.zeros_like(spike_train[-1]).to(device)

    for spike in spike_train:
        grad = spike.grad.data
        grad_abs = torch.abs(grad)

        if torch.max(grad_abs) > 1e-32:
            # Gradient-to-spike conversion: sample a mask from the min-max
            # normalised gradient magnitude and apply the gradient sign.
            # NOTE(review): if every magnitude is identical the
            # normalisation below divides by zero.
            grad_sign = torch.sign(grad)
            if targeted:
                grad_sign = -grad_sign
            grad_min = torch.min(grad_abs)
            grad_norm = (grad_abs - grad_min) / (torch.max(grad_abs) -
                                                 grad_min)
            grad_mask = torch.bernoulli(grad_norm)
            G2S = grad_sign * grad_mask
            # Keep the perturbed spike inside [0, 1].
            ik += torch.clamp(G2S + spike, 0.0, 1.0) - spike
        else:
            # Gradient trigger: no usable gradient, so flip random positions
            # chosen with probability gamma.
            GT = torch.bernoulli(torch.ones_like(grad) * gamma)
            ik += (GT.bool() ^ spike.bool()).float() - spike

    ik /= T
    return ik


def _craft_adversarial(source_net, encoder, img, goal_label, class_num, T,
                       iter_num, gamma, perturbation, device, targeted):
    """Iteratively perturb ``img`` using gradients from ``source_net``.

    For an untargeted attack the cross entropy with respect to ``goal_label``
    (the true label) is used; for a targeted attack the mean squared error
    against the one-hot encoding of ``goal_label`` (the target label) is used.
    A step is only applied when its L2 norm is below ``perturbation``.

    Returns:
        The perturbed image, clamped to ``[0, 1]`` after every applied step.
    """
    source_net.train()

    for _ in range(iter_num):
        spike_train = []

        for t in range(T):
            spike = encoder(img).float()
            # Gradients with respect to the input spikes drive the attack.
            spike.requires_grad = True
            spike_train.append(spike)
            if t == 0:
                out_spikes_counter = source_net(spike).unsqueeze(0)
            else:
                out_spikes_counter += source_net(spike).unsqueeze(0)

        out_spikes_counter_frequency = out_spikes_counter / T

        if targeted:
            loss = F.mse_loss(out_spikes_counter_frequency,
                              F.one_hot(goal_label, class_num).float())
        else:
            loss = F.cross_entropy(out_spikes_counter_frequency, goal_label)

        loss.backward()

        ik = _spike_perturbation(spike_train, gamma, T, device, targeted)

        l2_norm = torch.norm(ik.view(ik.size()[0], -1), dim=1).item()

        # Steps whose norm reaches the budget are discarded.
        if l2_norm < perturbation:
            img = torch.clamp(img + ik, 0.0, 1.0)

        # Clear network state and gradients before the next iteration.
        source_net.reset_()
        for p in spike_train:
            p.grad.data.zero_()
        for p in source_net.parameters():
            p.grad.data.zero_()

    return img


def _predict_labels(source_net, target_net, encoder, img, T):
    """Run both networks for ``T`` steps on ``img`` and return their argmax.

    Each network receives its own freshly encoded input at every step.  Both
    networks have ``reset_()`` called after the predictions are extracted.

    Returns:
        ``(source_prediction, target_prediction)`` index tensors.
    """
    for t in range(T):
        source_out = source_net(encoder(img).float()).unsqueeze(0)
        target_out = target_net(encoder(img).float()).unsqueeze(0)
        if t == 0:
            source_counter = source_out
            target_counter = target_out
        else:
            source_counter += source_out
            target_counter += target_out

    source_pred = source_counter.max(1)[1]
    target_pred = target_counter.max(1)[1]

    source_net.reset_()
    target_net.reset_()

    return source_pred, target_pred


def main():
    """Run a transfer attack from a source to a target spiking MNIST model.

    Adversarial images are crafted with gradients of ``source_net`` and then
    evaluated on both ``source_net`` and ``target_net``.  With
    ``attack_type == 'UT'`` one untargeted attack per test image is run (up
    to 250 attacks).  Otherwise every image is attacked once for each of the
    other ``class_num - 1`` labels (up to 270 attacks).  The mean L2
    perturbation and both success rates are printed at the end.

    All settings are hard-coded below.  ``Net`` and ``encoding`` are expected
    to be provided by the surrounding module.
    """
    device = 'cuda:3'
    dataset_dir = '../../dataset/'
    class_num = 10
    T = 50
    phase = 'BIM'

    torch.cuda.empty_cache()

    encoder = encoding.PoissonEncoder()

    if phase == 'BIM':
        source_model_path = './models/mnist_spike_v1.pth'
        target_model_path = './models/mnist_spike_v2.pth'

        gamma = 0.05
        iter_num = 50
        perturbation = 3.1
        attack_type = 'T'

        test_data_loader = torch.utils.data.DataLoader(
            dataset=torchvision.datasets.MNIST(
                root=dataset_dir,
                train=False,
                transform=torchvision.transforms.ToTensor(),
                download=True),
            batch_size=1,
            shuffle=False,
            drop_last=False)

        # map_location lets checkpoints saved from another device be loaded.
        source_net = Net().to(device)
        source_net.load_state_dict(
            torch.load(source_model_path, map_location=device))

        target_net = Net().to(device)
        target_net.load_state_dict(
            torch.load(target_model_path, map_location=device))

        target_net.eval()

        targeted = attack_type != 'UT'
        # Stop after this many attacks; checked once per test image.
        max_tests = 270 if targeted else 250
        # A targeted run tries every label offset; an untargeted run attacks
        # each image exactly once.
        offsets = range(1, class_num) if targeted else (0,)

        mean_p = 0.0
        test_sum = 0
        source_success_sum = 0
        target_success_sum = 0

        for X, y in test_data_loader:
            for i in offsets:
                img, label = X.to(device), y.to(device)
                # Independent copy of the clean image, kept to measure the
                # total perturbation afterwards.
                img_ori = torch.rand_like(img).copy_(img)

                test_sum += 1

                if targeted:
                    goal_label = (label + i) % class_num
                else:
                    goal_label = label
                    print('Img %d' % test_sum)

                img = _craft_adversarial(source_net, encoder, img,
                                         goal_label, class_num, T, iter_num,
                                         gamma, perturbation, device,
                                         targeted)

                source_net.eval()

                with torch.no_grad():
                    img_diff = img - img_ori

                    l2_norm = torch.norm(
                        img_diff.view(img_diff.size()[0], -1), dim=1).item()
                    print('Total Perturbation: %f' % l2_norm)

                    mean_p += l2_norm

                    source_pred, target_pred = _predict_labels(
                        source_net, target_net, encoder, img, T)

                    if targeted:
                        # Success means the target label is predicted.
                        source_hits = source_pred == goal_label
                        target_hits = target_pred == goal_label
                    else:
                        # Success means the true label is no longer
                        # predicted.
                        source_hits = source_pred != goal_label
                        target_hits = target_pred != goal_label

                    source_attack_flag = source_hits.float().sum().item()
                    source_success_sum += source_attack_flag

                    target_attack_flag = target_hits.float().sum().item()
                    target_success_sum += target_attack_flag

                    if source_attack_flag > 0.5:
                        print('Source Attack Success')
                    else:
                        print('Source Attack Failure')

                    if target_attack_flag > 0.5:
                        print('Target Attack Success')
                    else:
                        print('Target Attack Failure')

            if test_sum >= max_tests:
                break

        # Average over the attacks actually performed, so the result is also
        # correct when the loader runs out before the cap is reached.
        if test_sum > 0:
            mean_p /= test_sum

        print('Mean Perturbation: %.2f' % mean_p)
        print('source_success_sum: %d' % source_success_sum)
        print('target_success_sum: %d' % target_success_sum)
        print('test_sum: %d' % test_sum)
        print('source_success_rate: %.2f%%' %
              (100 * source_success_sum / test_sum))
        print('target_success_rate: %.2f%%' %
              (100 * target_success_sum / test_sum))