def train(epoch):
    """Run one training epoch over ``trainloader``.

    Every batch is transformed by ``utils.random_mask_batch_one_sample``
    (called with ``args.band_size`` and ``reuse_noise=True``) before the
    forward pass; the loss from ``criterion`` is back-propagated and
    ``optimizer`` takes one step per batch. The running mean loss and the
    running training accuracy are reported through ``progress_bar``.

    Relies on module-level globals defined elsewhere: ``net``,
    ``trainloader``, ``device``, ``optimizer``, ``criterion``, ``args``,
    ``utils`` and ``progress_bar``.

    Args:
        epoch: Epoch index; only used in the header line printed first.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        # The network sees the transformed batch, never the raw inputs.
        outputs = net(
            utils.random_mask_batch_one_sample(inputs,
                                               args.band_size,
                                               reuse_noise=True))
        # Accuracy is measured on the outputs computed before this step's
        # parameter update.
        _, predicted = outputs.max(1)
        correct += predicted.eq(targets).sum().item()

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        total += targets.size(0)

        train_loss += loss.item()

        progress_bar(
            batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
            (train_loss /
             (batch_idx + 1), 100. * correct / total, correct, total))
def test():
    """Report clean, certified and adversarial accuracy over ``testloader``.

    For every batch this function:

    * saves a grid of the clean inputs to ``baseline_<batch_idx>.jpg``;
    * builds adversarial inputs with ``attacker.perturb`` (passing
      ``float('inf')`` and ``random_count=args.randomizations``);
    * runs ``utils.predict_and_certify`` on both the clean and the attacked
      batch with ``args.band_size``, ``args.size_of_attack``, ``10`` and
      ``threshold=args.threshhold``;
    * saves a grid of the attacked inputs to ``attacks_<batch_idx>.jpg``.

    A clean sample counts as "certified" only when its prediction is correct
    AND it is set in the certificate mask returned for the clean batch.
    Running counts are shown with ``progress_bar``; a summary is printed at
    the end.

    Relies on module-level globals defined elsewhere: ``net``,
    ``testloader``, ``device``, ``attacker``, ``args``, ``utils``,
    ``save_image``, ``make_grid`` and ``progress_bar``.
    """
    correctclean = 0
    correctattacked = 0
    cert_correct = 0
    total = 0

    def _percent(count):
        # Percentage over all images seen so far; 0.0 for an empty loader so
        # the final summary does not raise ZeroDivisionError.
        return (100. * count) / total if total else 0.0

    for batch_idx, (inputs, targets) in enumerate(testloader):
        inputs, targets = inputs.to(device), targets.to(device)
        total += targets.size(0)

        save_image(make_grid(inputs, nrow=10),
                   "baseline_" + str(batch_idx) + ".jpg")
        # NOTE(review): intentionally not under torch.no_grad(); the attack
        # presumably needs input gradients — confirm before changing.
        attacked = attacker.perturb(inputs,
                                    targets,
                                    float('inf'),
                                    random_count=args.randomizations)
        predictionsclean, certyn = utils.predict_and_certify(
            inputs,
            net,
            args.band_size,
            args.size_of_attack,
            10,
            threshold=args.threshhold)
        # The certificate mask for the attacked batch is not used here.
        predictionsattacked, _ = utils.predict_and_certify(
            attacked,
            net,
            args.band_size,
            args.size_of_attack,
            10,
            threshold=args.threshhold)

        # Per-sample correctness of the clean predictions, computed once.
        clean_hits = predictionsclean.eq(targets)
        correctclean += clean_hits.sum().item()
        correctattacked += (predictionsattacked.eq(targets)).sum().item()
        cert_correct += (clean_hits & certyn).sum().item()
        save_image(make_grid(attacked, nrow=10),
                   "attacks_" + str(batch_idx) + ".jpg")

        progress_bar(
            batch_idx, len(testloader),
            'Clean Acc: %.3f%% (%d/%d) Cert: %.3f%% (%d/%d) Adv Acc: %.3f%% (%d/%d)'
            % (_percent(correctclean), correctclean, total,
               _percent(cert_correct), cert_correct, total,
               _percent(correctattacked), correctattacked, total))
    print('Using band size ' + str(args.band_size) + ' with threshhold ' +
          str(args.threshhold))
    print('Size of Attack Patch ' + str(args.size_of_attack) + '*' +
          str(args.size_of_attack))
    print('Total images: ' + str(total))
    print('Clean Correct: ' + str(correctclean) + ' (' +
          str(_percent(correctclean)) + '%)')
    print('Certified: ' + str(cert_correct) + ' (' +
          str(_percent(cert_correct)) + '%)')
    print('Attacked Correct: ' + str(correctattacked) + ' (' +
          str(_percent(correctattacked)) + '%)')
def test():
    """Measure accuracy of ``net`` on transformed test inputs.

    Each batch of ``testloader`` is passed through
    ``utils.random_mask_batch_one_sample`` (with ``args.block_size`` and
    ``reuse_noise=False``) before the forward pass. Evaluation runs under
    ``torch.no_grad()``. A running accuracy is shown with ``progress_bar``,
    and a single summary line is printed after the whole test set.

    Relies on module-level globals defined elsewhere: ``net``,
    ``testloader``, ``device``, ``args``, ``utils`` and ``progress_bar``.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            total += targets.size(0)
            outputs = net(
                utils.random_mask_batch_one_sample(inputs,
                                                   args.block_size,
                                                   reuse_noise=False))
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            progress_bar(
                batch_idx, len(testloader), ' Acc: %.3f%% (%d/%d)' %
                (100. * correct / total, correct, total))
    # Summary is emitted once for the whole test set, not once per batch.
    print('Correct: ' + str(correct) + ' out of ' + str(total))
# Example #4
def test():
    """Report accuracy and certification rates over ``testloader``.

    Each batch is classified by ``utils.predict_and_certify`` (called with
    ``args.keep``, ``args.bandwidth``, ``args.size_to_certify``, ``10`` and
    ``threshold=args.threshhold``) under ``torch.no_grad()``. The helper
    returns the predictions and a per-sample certificate mask. The function
    counts:

    * ``correct`` - predictions equal to the target;
    * ``cert_correct`` - correct AND certified;
    * ``cert_incorrect`` - wrong AND certified.

    Running counts are shown with ``progress_bar``; a summary is printed at
    the end.

    Relies on module-level globals defined elsewhere: ``net``,
    ``testloader``, ``device``, ``args``, ``utils`` and ``progress_bar``.
    """
    correct = 0
    cert_correct = 0
    cert_incorrect = 0
    total = 0

    def _percent(count):
        # Percentage over all images seen so far; 0.0 for an empty loader so
        # the final summary does not raise ZeroDivisionError.
        return (100. * count) / total if total else 0.0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            total += targets.size(0)
            predictions, certyn = utils.predict_and_certify(
                inputs,
                net,
                args.keep,
                args.bandwidth,
                args.size_to_certify,
                10,
                threshold=args.threshhold)

            # Per-sample correctness mask, computed once and reused below.
            hits = predictions.eq(targets)
            correct += hits.sum().item()
            cert_correct += (hits & certyn).sum().item()
            cert_incorrect += (~hits & certyn).sum().item()

            progress_bar(
                batch_idx, len(testloader),
                'Acc: %.3f%% (%d/%d) Cert: %.3f%% (%d/%d)' %
                (_percent(correct), correct, total,
                 _percent(cert_correct), cert_correct, total))

    print('Band size: ' + str(args.bandwidth) + ' threshhold: ' +
          str(args.threshhold) + ' num bands: ' + str(args.keep))
    print('Total images: ' + str(total))
    print('Correct: ' + str(correct) + ' (' + str(_percent(correct)) +
          '%)')
    print('Certified Correct class: ' + str(cert_correct) + ' (' +
          str(_percent(cert_correct)) + '%)')
    print('Certified Wrong class: ' + str(cert_incorrect) + ' (' +
          str(_percent(cert_incorrect)) + '%)')
# Example #5
def test():
    """Report clean and adversarial accuracy of the undefended baseline.

    For every batch of ``testloader`` this function:

    * saves a grid of the clean inputs to ``baseline_baseline_<batch_idx>.jpg``;
    * builds adversarial inputs with ``attacker.perturb`` (passing
      ``float('inf')`` and ``random_count=args.randomizations``);
    * takes the arg-max of ``net`` on the clean and the attacked batch;
    * saves a grid of the attacked inputs to ``attacks_baseline_<batch_idx>.jpg``.

    Running counts are shown with ``progress_bar``; a summary is printed at
    the end.

    Relies on module-level globals defined elsewhere: ``net``,
    ``testloader``, ``device``, ``attacker``, ``args``, ``save_image``,
    ``make_grid`` and ``progress_bar``.
    """
    correctclean = 0
    correctattacked = 0
    total = 0

    def _percent(count):
        # Percentage over all images seen so far; 0.0 for an empty loader so
        # the final summary does not raise ZeroDivisionError.
        return (100. * count) / total if total else 0.0

    for batch_idx, (inputs, targets) in enumerate(testloader):
        inputs, targets = inputs.to(device), targets.to(device)
        total += targets.size(0)

        save_image(make_grid(inputs, nrow=10),
                   "baseline_baseline_" + str(batch_idx) + ".jpg")
        # The attack runs with autograd enabled (it presumably needs input
        # gradients — confirm); only the evaluation passes below are grad-free.
        attacked = attacker.perturb(inputs,
                                    targets,
                                    float('inf'),
                                    random_count=args.randomizations)
        # Pure evaluation: no autograd graph is needed for the predictions.
        with torch.no_grad():
            _, predictionsclean = net(inputs).max(1)
            _, predictionsattacked = net(attacked).max(1)

        correctclean += (predictionsclean.eq(targets)).sum().item()
        correctattacked += (predictionsattacked.eq(targets)).sum().item()
        save_image(make_grid(attacked, nrow=10),
                   "attacks_baseline_" + str(batch_idx) + ".jpg")

        progress_bar(
            batch_idx, len(testloader),
            'Clean Acc: %.3f%% (%d/%d) Adv Acc: %.3f%% (%d/%d)' %
            (_percent(correctclean), correctclean, total,
             _percent(correctattacked), correctattacked, total))
    print('Baseline')
    print('Size of Attack Patch ' + str(args.size_of_attack) + '*' +
          str(args.size_of_attack))
    print('Total images: ' + str(total))
    print('Clean Correct: ' + str(correctclean) + ' (' +
          str(_percent(correctclean)) + '%)')
    print('Attacked Correct: ' + str(correctattacked) + ' (' +
          str(_percent(correctattacked)) + '%)')
# Example #6
# Build the requested architecture. An unknown name fails here with a clear
# message instead of surfacing later as a NameError on ``net``.
if args.model == 'resnet50':
    net = resnet.ResNet50()
elif args.model == 'resnet18':
    net = resnet.ResNet18()
else:
    raise ValueError('Unsupported model: ' + str(args.model))

net = net.to(device)
if device == 'cuda':
    # Multi-GPU wrapper plus cuDNN benchmark mode.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

# Explicit checks rather than ``assert`` so they still run under ``python -O``.
if not os.path.isdir(checkpoint_dir):
    raise FileNotFoundError('Error: no checkpoint directory found!')
resume_file = '{}/{}'.format(checkpoint_dir, args.checkpoint)
if not os.path.isfile(resume_file):
    raise FileNotFoundError('Error: no checkpoint file found at ' + resume_file)
# map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
checkpoint = torch.load(resume_file, map_location=device)
net.load_state_dict(checkpoint['net'])

net.eval()
all_batches = []
# Certification is inference only; no autograd graph is needed.
with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
        inputs, targets = inputs.to(device), targets.to(device)
        batch_radii = utils.certify(inputs, targets, net, args.alpha,
                                    args.band_size, args.attack_size,
                                    args.predsamples, args.boundsamples)
        all_batches.append(batch_radii)
        progress_bar(batch_idx, len(testloader))
# One entry per test sample. Entries equal to 1 are reported as
# "certify correct" and entries equal to -1 as "certify wrong".
out = torch.cat(all_batches)
print('band size: ' + str(args.band_size))
print('certify correct: ' + str(float((out == 1).sum()) / out.shape[0]))
print('certify wrong: ' + str(float((out == -1).sum()) / out.shape[0]))