Example #1
0
        model.to(DEVICE)

    if args.adv is None and val_accs.avg >= best_acc:
        best_acc = val_accs.avg
        best_epoch = epoch
        best_dict = deepcopy(model.state_dict())

    if not (epoch + 1) % args.save_freq:
        save_checkpoint(model.state_dict(),
                        os.path.join(
                            args.save_folder, args.save_name +
                            'acc{}_{}.pth'.format(val_accs.avg, (epoch + 1))),
                        cpu=True)

if args.adv is None:
    # best_dict is only recorded when args.adv is None (clean training), so
    # only then is there a best-validation snapshot to restore.
    model.load_state_dict(best_dict)

# Final evaluation of the model on the test loader.
test_accs = AverageMeter()
test_losses = AverageMeter()

# Switch BatchNorm/Dropout to inference mode explicitly rather than relying
# on whatever mode the last training/validation phase left the model in
# (in train mode BatchNorm would also keep updating its running statistics,
# even under no_grad).
model.eval()
with torch.no_grad():
    for i, (images, labels) in enumerate(tqdm.tqdm(test_loader, ncols=80)):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        logits = model(images)
        loss = F.cross_entropy(logits, labels)

        # Per-batch top-1 accuracy and loss; AverageMeter aggregates them
        # (presumably an unweighted mean over batches — TODO confirm).
        test_accs.append((logits.argmax(1) == labels).float().mean().item())
        test_losses.append(loss.item())

if args.adv is not None:
    pth_file = '/media/unknown/Data/PLP/fast_adv/defenses/weights/best/2_mixed_attention_cifar10_ep_12_val_acc0.8895.pth'
    #pth_file = '/media/unknown/Data/PLP/fast_adv/defenses/weights/best/2_2AT_cifar10_ep_29_val_acc0.8870.pth'

    model = NormalizedModel(model=m, mean=image_mean, std=image_std).to(device)
    #model = model.to(device)
    print('loading data for defense using %s ....' % target_model)

    test_loader = load_data_for_defense()['dev_data']

    # pth_file = os.path.join(weights_path, 'cifar10acc0.9090267625855811_40.pth')#18.24944_17.2555_ep_31__acc0.7470.pthep_36_val_acc0.9844.pthep_36_val_acc0.9844.pth#glob.glob(os.path.join(weights_path, 'cifar10_20_0.73636.pth'))#[0]
    print('loading weights from : ', pth_file)
    #model_dict = torch.load('jpeg_weight/18.5459_0.9_jpeg_WRN_DDNacc0.9553740539334037_20_0.60.pth')
    #model_dict = torch.load('jpeg_weight/ALP_smooth_p_44_val_acc0.9409.pth')
    #model_dict = torch.load('jpeg_weight/JPEG_ALP_smooth_acc0.9401_all_0.786.pth')
    model_dict = torch.load(pth_file)
    model.load_state_dict(model_dict, False)
    model.eval()
    test_accs = AverageMeter()
    test_losses = AverageMeter()
    widgets = [
        'test :',
        Percentage(), ' ',
        Bar('#'), ' ',
        Timer(), ' ',
        ETA(), ' ',
        FileTransferSpeed()
    ]
    pbar = ProgressBar(widgets=widgets)
    with torch.no_grad():
        for batch_data in pbar(test_loader):
            images, labels = batch_data['image'].to(
Example #3
0
        weight_025conv_mixatten='/media/unknown/Data/PLP/fast_adv/defenses/weights/best/0.25MixedAttention_mixed_attention_cifar10_ep_50_val_acc0.8720.pth'
        weight_05conv_mixatten = '/media/unknown/Data/PLP/fast_adv/defenses/weights/shape_0.5_cifar10_mixed_Attention/cifar10acc0.8434999763965607_130.pth'
        weight_1conv_mixatten = '/media/unknown/Data/PLP/fast_adv/defenses/weights/best/1MixedAttention_mixed_attention_cifar10_ep_25_val_acc0.7080.pth'

        weight_shape_alp='/media/unknown/Data/PLP/fast_adv/defenses/weights/best/shape_ALP_cifar10_ep_79_val_acc0.7625.pth'
        weight_attention = '/media/unknown/Data/PLP/fast_adv/defenses/weights/cifar10_Attention/cifar10acc0.8729999780654907_120.pth'

        weight_025conv_mixatten_ALP = '/media/unknown/Data/PLP/fast_adv/defenses/weights/best/0.25Mixed+ALP_cifar10_ep_85_val_acc0.8650.pth'

        weight_smooth = '/media/unknown/Data/PLP/fast_adv/defenses/weights/best/2random_smooth_cifar10_ep_120_val_acc0.8510.pth'
        weight_05smooth = '/media/unknown/Data/PLP/fast_adv/defenses/weights/shape_0.5_random/cifar10acc0.6944999784231186_50.pth'
        weight_025smooth = '/media/unknown/Data/PLP/fast_adv/defenses/weights/best/0.25random_smooth_cifar10_ep_146_val_acc0.8070.pth'
        weight_1smooth = '/media/unknown/Data/PLP/fast_adv/defenses/weights/best/1random_smooth_cifar10_ep_107_val_acc0.5380.pth'
        print('loading weights from : ', weight_AT)
        model_dict = torch.load(weight_AT)
        model.load_state_dict(model_dict)
        model.eval()
        model_dict2 = torch.load(weight_025conv_mixatten_ALP)
        model2.load_state_dict(model_dict2)
        model2.eval()
        test_accs = AverageMeter()
        test_losses = AverageMeter()
        widgets = ['test :', Percentage(), ' ', Bar('#'), ' ', Timer(),
                   ' ', ETA(), ' ', FileTransferSpeed()]
        pbar = ProgressBar(widgets=widgets)
        with torch.no_grad():
            for batch_data in pbar(test_loader):
                images, labels = batch_data['image'].to(device), batch_data['label_idx'].to(device)
                noise = torch.randn_like(images, device='cuda') * 0.2
                image_shape = images + noise
                #image_shape = torch.renorm(image_shape - images, p=2, dim=0, maxnorm=1) + images
Example #4
0
# ########
# model2=NormalizedModel(model=m, mean=image_mean, std=image_std).to(DEVICE)  # keep images in the [0, 1] range
# model_dict2 = torch.load('../defenses/weights/AT_cifar10_clean0.879_adv.pth')
# model2.load_state_dict(model_dict2)
# model3=NormalizedModel(model=m, mean=image_mean, std=image_std).to(DEVICE)  # keep images in the [0, 1] range
# model_dict3 = torch.load('../defenses/weights/best/ALP_cifar10_ep_39_val_acc0.8592.pth')
# model3.load_state_dict(model_dict3)
# model4=NormalizedModel(model=m, mean=image_mean, std=image_std).to(DEVICE)  # keep images in the [0, 1] range
# model_dict4 = torch.load('../defenses/weights/best/PLP1_cifar10_ep_29_val_acc0.8636.pth')
# model4.load_state_dict(model_dict4)

# Build a NormalizedModel around the base network and load the '2Norm'
# checkpoint into it.
weight_norm = '../defenses/weights/best/2Norm_cifar10_ep_184_val_acc0.9515.pth'
# NOTE(review): this wraps the same base network `m` used by other
# NormalizedModel instances in this script; if NormalizedModel keeps a
# reference rather than a copy, loading these weights also changes those
# wrappers' parameters — confirm that is intended.
model5 = NormalizedModel(model=m, mean=image_mean, std=image_std).to(
    DEVICE)  # keep images in the [0, 1] range
# map_location keeps loading working when the checkpoint was saved on a
# device that is not available here (e.g. a GPU checkpoint on a CPU host).
model_dict5 = torch.load(weight_norm, map_location=DEVICE)
model5.load_state_dict(model_dict5)

#model.eval()
#with torch.no_grad():
for i, (images, labels) in enumerate(tqdm.tqdm(test_loader, ncols=80)):
    images, labels = images.to(DEVICE), labels.to(DEVICE)
    logits = model5(images)
    # loss = F.cross_entropy(logits, labels)
    # print(logits)
    test_accs = AverageMeter()
    test_losses = AverageMeter()
    test_accs.append((logits.argmax(1) == labels).float().mean().item())

    ################ADV########################
    attacker = DDN(steps=100, device=DEVICE)
    attacker2 = DeepFool(device=DEVICE)