Exemple #1
0
def main(_):
    """Train an MNIST classifier and report clean, FGM and PGD test accuracy.

    Configuration comes from the module-level ``FLAGS`` object: ``model``
    ("cnn" or "pynet"), ``nb_epochs``, ``adv_train`` (train on PGD examples
    instead of clean ones) and ``eps`` (L-inf attack budget for both the
    training-time PGD and the test-time FGM / PGD attacks).

    :param _: unused positional argument (presumably argv from an
              ``app.run``-style launcher — TODO confirm).
    :raises NotImplementedError: if ``FLAGS.model`` is not "cnn" or "pynet".
    """
    # Load training and test data
    data = ld_mnist()

    # Instantiate model, loss, and optimizer for training
    if FLAGS.model == "cnn":
        net = CNN(in_channels=1)
    elif FLAGS.model == "pynet":
        net = PyNet(in_channels=1)
    else:
        raise NotImplementedError

    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = net.to(device)
    loss_fn = torch.nn.CrossEntropyLoss(reduction="mean")
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    # Train the model (on PGD examples instead of clean ones if adv_train is set)
    net.train()
    for epoch in range(1, FLAGS.nb_epochs + 1):
        train_loss = 0.0  # sum of the per-batch mean losses over the epoch
        for x, y in data.train:
            x, y = x.to(device), y.to(device)
            if FLAGS.adv_train:
                # Replace clean example with adversarial example for adversarial training
                x = projected_gradient_descent(net, x, FLAGS.eps, 0.01, 40, np.inf)
            optimizer.zero_grad()
            loss = loss_fn(net(x), y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        print("epoch: {}/{}, train loss: {:.3f}".format(
            epoch, FLAGS.nb_epochs, train_loss))

    # Evaluate on clean and adversarial data
    net.eval()
    report = EasyDict(nb_test=0, correct=0, correct_fgm=0, correct_pgd=0)
    for x, y in data.test:
        x, y = x.to(device), y.to(device)
        # The attacks differentiate the model w.r.t. its input, so they must
        # run with autograd enabled.
        x_fgm = fast_gradient_method(net, x, FLAGS.eps, np.inf)
        x_pgd = projected_gradient_descent(net, x, FLAGS.eps, 0.01, 40, np.inf)
        # Pure inference: no autograd graph is needed for the predictions.
        with torch.no_grad():
            _, y_pred = net(x).max(1)  # model prediction on clean examples
            _, y_pred_fgm = net(x_fgm).max(1)  # prediction on FGM adversarial examples
            _, y_pred_pgd = net(x_pgd).max(1)  # prediction on PGD adversarial examples
        report.nb_test += y.size(0)
        report.correct += y_pred.eq(y).sum().item()
        report.correct_fgm += y_pred_fgm.eq(y).sum().item()
        report.correct_pgd += y_pred_pgd.eq(y).sum().item()
    print("test acc on clean examples (%): {:.3f}".format(
        report.correct / report.nb_test * 100.0))
    print("test acc on FGM adversarial examples (%): {:.3f}".format(
        report.correct_fgm / report.nb_test * 100.0))
    print("test acc on PGD adversarial examples (%): {:.3f}".format(
        report.correct_pgd / report.nb_test * 100.0))
Exemple #2
0
def attack(model, dataset, attack_alg='fgsm', eps=0.1, device='cuda',
           eps_iter=0.01, nb_iter=40):
    """Measure a model's accuracy on adversarial examples crafted from ``dataset``.

    :param model: classifier returning logits.
    :param dataset: iterable of ``(x, y)`` batches. Adversarial inputs are
        clipped to [0, 1], so inputs presumably lie in that range — confirm.
    :param attack_alg: "fgsm" (fast gradient method) or "pgd" (projected
        gradient descent); both use the L-infinity norm.
    :param eps: maximum L-infinity perturbation.
    :param device: device the model and data are moved to.
    :param eps_iter: PGD step size (ignored for "fgsm").
    :param nb_iter: number of PGD iterations (ignored for "fgsm").
    :return: the adversarial accuracy as computed by ``Accuracy().compute()``
        (also printed).
    :raises ValueError: if ``attack_alg`` is not "fgsm" or "pgd".
    """
    # Validate before doing any work; previously any value silently ran FGSM.
    if attack_alg not in ('fgsm', 'pgd'):
        raise ValueError(
            "attack_alg must be 'fgsm' or 'pgd', got {!r}".format(attack_alg))

    accuracy = Accuracy().to(device)
    model = model.to(device)
    model.eval()
    for x, y in dataset:
        x, y = x.to(device), y.to(device)
        # Crafting the attack needs input gradients, so it runs outside no_grad().
        if attack_alg == 'fgsm':
            x_adv = fast_gradient_method(model, x, eps, np.inf,
                                         clip_min=0, clip_max=1)
        else:
            x_adv = projected_gradient_descent(model, x, eps, eps_iter, nb_iter,
                                               np.inf, clip_min=0, clip_max=1)
        with torch.no_grad():
            _, y_pred_adv = model(x_adv).max(1)
        accuracy(y_pred_adv, y)
    acc = accuracy.compute()
    print(acc)
    return acc
Exemple #3
0
def projected_gradient_descent(
    model_fn,
    x,
    eps,
    eps_iter,
    nb_iter,
    norm,
    clip_min=None,
    clip_max=None,
    y=None,
    targeted=False,
    rand_init=True,
    rand_minmax=None,
    sanity_checks=True,
):
    """
    Projected Gradient Descent (PGD) attack.

    This function implements either the Basic Iterative Method
    (Kurakin et al. 2016) when rand_init is set to False, or the
    Madry et al. (2017) method when rand_init is set to True.
    Paper link (Kurakin et al. 2016): https://arxiv.org/pdf/1607.02533.pdf
    Paper link (Madry et al. 2017): https://arxiv.org/pdf/1706.06083.pdf
    :param model_fn: a callable that takes an input tensor and returns the model logits.
    :param x: input tensor.
    :param eps: epsilon (input variation parameter); see https://arxiv.org/abs/1412.6572.
    :param eps_iter: step size for each attack iteration.
    :param nb_iter: Number of attack iterations.
    :param norm: Order of the norm (mimics NumPy). Possible values: np.inf or 2
              (norm=1 raises NotImplementedError).
    :param clip_min: (optional) float. Minimum float value for adversarial example components.
    :param clip_max: (optional) float. Maximum float value for adversarial example components.
    :param y: (optional) Tensor with true labels. If targeted is true, then provide the
              target label. Otherwise, only provide this parameter if you'd like to use true
              labels when crafting adversarial samples. Otherwise, model predictions are used
              as labels to avoid the "label leaking" effect (explained in this paper:
              https://arxiv.org/abs/1611.01236). Default is None.
    :param targeted: (optional) bool. Is the attack targeted or untargeted?
              Untargeted, the default, will try to make the label incorrect.
              Targeted will instead try to move in the direction of being more like y.
    :param rand_init: (optional) bool. Whether to start the attack from a randomly perturbed x.
    :param rand_minmax: (optional) float. The initial random perturbation is drawn from
              U(-rand_minmax, rand_minmax). Effective only when rand_init is True.
              Defaults to eps.
    :param sanity_checks: bool, if True, assert that eps_iter <= eps, that x lies within
              [clip_min, clip_max] where given, and (for norm=np.inf with both bounds
              given) that eps + clip_min <= clip_max. The checks run before the attack
              loop. Turn them off to use less runtime / memory or for unit tests that
              intentionally pass strange input.
    :return: a tensor for the adversarial example
    :raises NotImplementedError: if norm == 1.
    :raises ValueError: for an unsupported norm, negative eps / eps_iter, or
              clip_min > clip_max.
    :raises AssertionError: if sanity_checks is True and a check fails.
    """
    if norm == 1:
        raise NotImplementedError(
            "It's not clear that FGM is a good inner loop"
            " step for PGD when norm=1, because norm=1 FGM "
            " changes only one pixel at a time. We need "
            " to rigorously test a strong norm=1 PGD "
            "before enabling this feature."
        )
    if norm not in [np.inf, 2]:
        raise ValueError("Norm order must be either np.inf or 2.")
    if eps < 0:
        raise ValueError(
            "eps must be greater than or equal to 0, got {} instead".format(eps)
        )
    if eps == 0:
        return x
    if eps_iter < 0:
        raise ValueError(
            "eps_iter must be greater than or equal to 0, got {} instead".format(
                eps_iter
            )
        )
    if eps_iter == 0:
        return x

    if sanity_checks:
        assert eps_iter <= eps, (eps_iter, eps)
    if clip_min is not None and clip_max is not None:
        if clip_min > clip_max:
            raise ValueError(
                "clip_min must be less than or equal to clip_max, got clip_min={} and clip_max={}".format(
                    clip_min, clip_max
                )
            )

    if sanity_checks:
        # Checks run up front (not after the attack loop) so bad input fails
        # fast. bool() collapses each 0-d result tensor to a Python bool; the
        # old np.all over a list of tensors needed a NumPy conversion that
        # fails for CUDA tensors.
        # If a data range was specified, check that the input was in that range.
        if clip_min is not None:
            assert bool(
                torch.all(torch.ge(x, torch.tensor(clip_min, device=x.device, dtype=x.dtype)))
            ), "x has components below clip_min={}".format(clip_min)
        if clip_max is not None:
            assert bool(
                torch.all(torch.le(x, torch.tensor(clip_max, device=x.device, dtype=x.dtype)))
            ), "x has components above clip_max={}".format(clip_max)
        # Only meaningful when both bounds exist; previously clip_max=None
        # made this comparison raise TypeError.
        if norm == np.inf and clip_min is not None and clip_max is not None:
            assert eps + clip_min <= clip_max, (eps, clip_min, clip_max)

    # Initialize loop variables
    if rand_init:
        if rand_minmax is None:
            rand_minmax = eps
        eta = torch.zeros_like(x).uniform_(-rand_minmax, rand_minmax)
    else:
        eta = torch.zeros_like(x)

    # Clip eta
    eta = clip_eta(eta, norm, eps)
    adv_x = x + eta
    if clip_min is not None or clip_max is not None:
        adv_x = torch.clamp(adv_x, clip_min, clip_max)

    if y is None:
        # Using model predictions as ground truth to avoid label leaking
        _, y = torch.max(model_fn(x), 1)

    for _ in range(nb_iter):
        adv_x = fast_gradient_method(
            model_fn,
            adv_x,
            eps_iter,
            norm,
            clip_min=clip_min,
            clip_max=clip_max,
            y=y,
            targeted=targeted,
        )

        # Clipping perturbation eta to norm norm ball
        eta = clip_eta(adv_x - x, norm, eps)
        adv_x = x + eta

        # Redo the clipping.
        # FGM already did it, but subtracting and re-adding eta can add some
        # small numerical error.
        if clip_min is not None or clip_max is not None:
            adv_x = torch.clamp(adv_x, clip_min, clip_max)

    return adv_x
Exemple #4
0
def validate(args):
    """Evaluate a model on clean inputs and on FGM / PGD adversarial examples.

    For every batch of the validation loader this records the cross-entropy
    loss and top-1 / top-5 accuracy on the clean inputs, plus top-1 / top-5
    accuracy on L-inf FGM and PGD adversarial examples with budget ``args.eps``.
    The attacks are crafted on the loader's inputs without clipping
    (``clip_min``/``clip_max`` are None) and no ``y`` is passed, so (at least
    for PGD) the model's own predictions are used as attack labels.

    :param args: parsed command-line namespace; fields read include ``eps``,
        ``model``, ``checkpoint``, ``amp``, ``data_dir``, ``batch_size``,
        ``valid_labels``, ``real_labels`` and ``log_freq``.
    :return: ``OrderedDict`` with clean and adversarial accuracies (percent),
        parameter count in millions, image size, crop pct and interpolation.
    :raises NotImplementedError: if ``args.real_labels`` is set; real-label
        evaluation has not been adapted to the adversarial metrics.
    """
    _logger.info(f'\n\n ---------------EVALUATION {args.eps}------------------------------- \n\n')
    _logger.info("Argument parser collected the following arguments:")
    for arg in vars(args):
        _logger.info(f"    {arg}:{getattr(args, arg)}")
    _logger.info("\n")

    if args.real_labels:
        # Fail fast instead of raising only after a full (expensive) evaluation pass.
        raise NotImplementedError(
            'real_labels evaluation is not supported together with adversarial examples')

    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
        else:
            _logger.warning("Neither APEX or Native Torch AMP is available.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast
        _logger.info('Validating in mixed precision with native PyTorch AMP.')
    elif args.apex_amp:
        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    else:
        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(m.numel() for m in model.parameters())
    _logger.info(
        f'Model {args.model} created, param count: {param_count} ({(float(param_count)/(10.0**6)):.1f} M)'
    )

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    criterion = nn.CrossEntropyLoss().cuda()

    dataset = create_dataset(
        root=args.data_dir, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_label_ids = {int(line.rstrip()) for line in f}
        # Boolean mask over the model's classes, used to index the logit columns.
        valid_labels = [i in valid_label_ids for i in range(args.num_classes)]
    else:
        valid_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1_fgm_ae = AverageMeter()
    top5_fgm_ae = AverageMeter()
    top1_pgd_ae = AverageMeter()
    top5_pgd_ae = AverageMeter()

    # PGD step size. projected_gradient_descent asserts eps_iter <= eps, so the
    # fixed 0.01 step is capped at eps to avoid crashing for small budgets.
    pgd_eps_iter = min(0.01, args.eps)

    model.eval()
    # Autograd stays enabled globally because the attacks differentiate the
    # model w.r.t. its input; only pure inference passes use torch.no_grad().
    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        inputs = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
        if args.channels_last:
            inputs = inputs.contiguous(memory_format=torch.channels_last)
        model(inputs)
    end = time.time()
    for batch_idx, (inputs, target) in enumerate(loader):
        if args.no_prefetcher:
            target = target.cuda()
            inputs = inputs.cuda()
        if args.channels_last:
            inputs = inputs.contiguous(memory_format=torch.channels_last)
        n = inputs.size(0)

        # compute output on the clean inputs
        with torch.no_grad():
            with amp_autocast():
                output = model(inputs)
            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

        # Generate adversarial examples for current inputs (needs input gradients).
        # NOTE(review): the attacks use the full logit set even when valid_labels
        # restricts the evaluated classes — confirm this is intended.
        input_fgm_ae = fast_gradient_method(
            model_fn=model,
            x=inputs,
            eps=args.eps,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        input_pgd_ae = projected_gradient_descent(
            model_fn=model,
            x=inputs,
            eps=args.eps,
            eps_iter=pgd_eps_iter,
            nb_iter=40,
            norm=np.inf,
            clip_min=None,
            clip_max=None,
        )
        # Predict with Adversarial Examples
        with torch.no_grad():
            with amp_autocast():
                output_fgm_ae = model(input_fgm_ae)
                output_pgd_ae = model(input_pgd_ae)
        if valid_labels is not None:
            # Score adversarial logits over the same class subset as the clean ones.
            output_fgm_ae = output_fgm_ae[:, valid_labels]
            output_pgd_ae = output_pgd_ae[:, valid_labels]

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), n)
        top1.update(acc1.item(), n)
        top5.update(acc5.item(), n)

        acc1_fgm_ae, acc5_fgm_ae = accuracy(output_fgm_ae, target, topk=(1, 5))
        acc1_pgd_ae, acc5_pgd_ae = accuracy(output_pgd_ae, target, topk=(1, 5))
        top1_fgm_ae.update(acc1_fgm_ae.item(), n)
        top5_fgm_ae.update(acc5_fgm_ae.item(), n)
        top1_pgd_ae.update(acc1_pgd_ae.item(), n)
        top5_pgd_ae.update(acc5_pgd_ae.item(), n)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % args.log_freq == 0:
            _logger.info(
                'Test: [{0:>4d}/{1}]  '
                'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
                'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
                'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f})  '
                'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})  '
                'FGM Acc@1: {top1_fgm.val:>7.3f} ({top1_fgm.avg:>7.3f})  '
                'PGD Acc@1: {top1_pgd.val:>7.3f} ({top1_pgd.avg:>7.3f})'.format(
                    batch_idx, len(loader), batch_time=batch_time,
                    rate_avg=n / batch_time.avg,
                    loss=losses, top1=top1, top5=top5,
                    top1_fgm=top1_fgm_ae, top1_pgd=top1_pgd_ae))

    top1a, top5a = top1.avg, top5.avg
    top1a_fgm_ae, top5a_fgm_ae = top1_fgm_ae.avg, top5_fgm_ae.avg
    top1a_pgd_ae, top5a_pgd_ae = top1_pgd_ae.avg, top5_pgd_ae.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        top1_fgm_ae=round(top1a_fgm_ae, 4),
        top5_fgm_ae=round(top5a_fgm_ae, 4),
        top1_pgd_ae=round(top1a_pgd_ae, 4),
        top5_pgd_ae=round(top5a_pgd_ae, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * [Regular] Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
       results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    _logger.info(' * [FGM Adversarial Attack] Acc@1 {:.3f}  Acc@5 {:.3f} '.format(
       results['top1_fgm_ae'], results['top5_fgm_ae']))
    _logger.info(' * [PGD Adversarial Attack] Acc@1 {:.3f}  Acc@5 {:.3f} '.format(
       results['top1_pgd_ae'], results['top5_pgd_ae']))

    return results
def run(
    #################################################################################################
    #                                          Parameters:                                          #
    #################################################################################################
    MODELS = None,
    PRINT_OUT = True,    # Print out results at end

    # imagenet-c:
    CORRUPT_IMG = False,
    COR_NUM = 7,    # 7 is snow, 8 is frost - but they both error out :/
    COR_SEVERITY = 5,

    # Adversarial Attacks:
    adversarial_attack = False,
    adversarial_type = 'hop_skip',  # one of 'fast', 'projected', 'hop_skip'; 'fast' or 'projected' recommended

    # Bit-flipping corruptions:
    stuck_at_faults = 0,  # This many bits will have "stuck-at faults" in the weights, permanently stuck at either 1 or 0
    weights_BER = 0,  # Bit Error Rate for weights (applied each batch, assuming weights are reloaded for each batch)
    activation_BER = 0,  # Bit Error Rate for activations, i.e. 1e-9 = ~(1 in 1000000000) errors in the activations

    # Model parameters:
    num_batches = 1,  # Number of loops performed, each with a new batch of images
    batch_size = 4,  # Number of images processed in a batch (in parallel)
    val_image_dir = 'val/',  # The directory where validation images are stored
    voting_heuristic = 'sum all',  # Determines the algorithm used to predict between multiple models

    cuda = torch.cuda.is_available()
):
    """Measure (ensemble) classification accuracy on random validation images under faults.

    Each of ``num_batches`` batches samples ``batch_size`` random files from
    ``val_image_dir``, optionally corrupts them (``corrupt``) and/or replaces
    them with adversarial examples crafted against the first model, runs every
    model in ``MODELS`` and combines their outputs with ``vote``. Weight and
    activation bit flips are injected via the project's bit-flipping helpers
    according to ``stuck_at_faults``, ``weights_BER`` and ``activation_BER``.

    :return: ``[percentage_correct, get_num_weight_flips(), get_num_activation_flips()]``.
    :raises ValueError: if ``adversarial_attack`` is true and ``adversarial_type``
        is not 'fast', 'projected' or 'hop_skip'.
    """
    # Validate before building models / loading data (previously exit() was called mid-run).
    if adversarial_attack and adversarial_type not in ('fast', 'projected', 'hop_skip'):
        raise ValueError("Unrecognized adversarial attack type: " + str(adversarial_type))
    if MODELS is None:
        MODELS = ['resnext101_32x8d', 'densenet161', 'inception_v3']  # For an ensemble, put >1 network here
    reset_bit_flip_counters()   # Do this at start in case calling run() multiple times

    #################################################################################################
    #                                           Runtime:                                            #
    #################################################################################################
    # Instantiate the model(s)
    networks = []
    for i, m in enumerate(MODELS):
        net = get_model(m)
        if cuda:
            net = net.cuda()
        net.name = str(i) + '_' + net.__class__.__name__  # Give the net a unique name (used by bit_flipping.py)
        net.eval()  # Put in evaluation mode (already pretrained)
        if stuck_at_faults != 0:
            net = flip_n_bits_in_weights(stuck_at_faults, net)  # Introduce stuck-ats
        if activation_BER != 0:  # If nonzero chance of activation bit flips
            net = add_activation_bit_flips(net, activation_BER)  # Add layers to flip activation bits
        networks.append(net)

    if CORRUPT_IMG:
        print('Corrupting with COR_NUM: ' + str(COR_NUM) + ' and COR_SEVERITY: ' + str(COR_SEVERITY))

    # Run each batch
    total_correct = 0
    for batch_num in range(num_batches):
        # Load images and prepare them in a batch
        image_paths = random.sample(os.listdir(val_image_dir), batch_size)
        gt_labels = torch.tensor([get_label(image) for image in image_paths])  # Ground-truth label for each image

        batch_t = torch.empty((batch_size, 3, 224, 224))  # Shape of [N, C, H, W]
        for i, image_path in enumerate(image_paths):
            # Context manager closes the file handle that Image.open keeps open.
            with Image.open(os.path.join(val_image_dir, image_path)) as raw_img:
                img = raw_img.convert("RGB")
            img = toSizeCenter(img)
            if CORRUPT_IMG:
                pic_np = np.array(img)  # numpy arr for corruption
                pic_np = corrupt(pic_np, severity=COR_SEVERITY, corruption_number=COR_NUM)  # See Readme for Calls
                img = Image.fromarray(np.uint8(pic_np))  # Back to PIL
            batch_t[i, :, :, :] = toTensor(img)
        if cuda:
            batch_t = batch_t.cuda()
        if adversarial_attack:
            # Attacks need input gradients, so they run with autograd enabled.
            if adversarial_type == 'fast':
                batch_t = fast_gradient_method(networks[0], batch_t, eps=0.25, norm=np.inf, sanity_checks=True)   # ~25%
            elif adversarial_type == 'projected':
                batch_t = projected_gradient_descent(networks[0], batch_t, 0.25, 0.01, 40, np.inf)                  # 0.25, 0.01, 40: ~6.25
            else:  # 'hop_skip' (validated at the top)
                batch_t = hop_skip_jump_attack(networks[0], batch_t, np.inf, verbose=True)

        # Run each network and store output in 'out'
        out = torch.empty(
            (len(MODELS), batch_size, 1000))  # Shape [M, N, 1000] where M = num models, and N = batch size
        for i, net in enumerate(networks):
            if weights_BER != 0:  # If nonzero chance of weight bit flips
                net = flip_stochastic_bits_in_weights(weights_BER, net)
            # Pure inference: no autograd graph is needed for the ensemble outputs.
            with torch.no_grad():
                out[i, :, :] = net(batch_t)

        predictions = vote(out, voting_heuristic)  # Returns predictions, with shape [N] (one prediction per image)
        num_correct = torch.sum(predictions == gt_labels).item()  # Item() pulls the integer out of the tensor

        total_correct += num_correct
        print("Batch %d:  %d / %d" % (batch_num, num_correct, batch_size))

    #################################################################################################
    #                                         Print Results:                                        #
    #################################################################################################

    total_images = batch_size * num_batches
    # Guard against zero batches / zero batch size instead of dividing by zero.
    percentage_correct = (total_correct / total_images) * 100 if total_images else 0.0
    if PRINT_OUT:
        print("Percentage Correct: %.2f%%" % percentage_correct)
        # Each batch is treated as one frame of a 32 fps stream: 32 batches/second * 60 seconds.
        inference_minutes = num_batches / (32 * 60)
        for model_name, net in zip(MODELS, networks):
            weight_flips = get_flips_in_weights(net)
            activation_flips = get_flips_in_activations(net)
            weight_bits = get_num_params(net) * 32  # 32 bits per parameter (presumably float32 weights)
            print(model_name + ':')
            print("\t Total bit flips in weights:", weight_flips, "or %.0f per minute of inference"
                  % (weight_flips / inference_minutes if inference_minutes else 0.0))
            print("\t Total bit flips in activations:", activation_flips, "or %.0f per minute of inference"
                  % (activation_flips / inference_minutes if inference_minutes else 0.0))
            print("\t", stuck_at_faults, "out of", weight_bits,
                  " weight bits permanently corrupted, or %.8f%%"
                  % ((stuck_at_faults / weight_bits) * 100))

    return [percentage_correct, get_num_weight_flips(), get_num_activation_flips()]
Exemple #6
0
# # Since no training is performed here, use torch.no_grad()
# with torch.no_grad():
#     X_test = test.test_data.view(len(test), 1, 28, 28).float().to(device)
#     Y_test = test.test_labels.to(device)

#     prediction = model(X_test)
#     correct_prediction = torch.argmax(prediction, 1) == Y_test
#     accuracy = correct_prediction.float().mean()
#     print('Accuracy:', accuracy.item())

# Counters for clean / FGM / PGD test accuracy.
# NOTE(review): the counter updates visible in the loop below are commented
# out; confirm whether `report` is filled in further down.
report = EasyDict(nb_test=0, correct=0, correct_fgm=0, correct_pgd=0)

# Evaluate `model` batch by batch, on clean inputs and under an FGM attack.
for x, y in test_data_loader:
    x, y = x.to(device), y.to(device)

    # Fast-gradient-method adversarial examples, L-inf norm, eps=0.5.
    # NOTE(review): eps=0.5 is hard-coded; presumably in input-pixel units — confirm.
    x_fgm = fast_gradient_method(model, x, 0.5, np.inf)

    # Argmax over dim 1 -> predicted class per sample
    # (assumes model returns (batch, num_classes) logits — TODO confirm).
    _, y_pred = model(x).max(1)

    # Predicted classes on the FGM adversarial inputs.
    _, y_pred_fgm = model(x_fgm).max(1)

    # report.nb_test += y.size(0)
    # report.correct += y_pred.eq(y).sum().item()
    # report.correct_fgm += y_pred_fgm.eq(y).sum().item()

    # print(
    #     "test acc on clean examples (%): {:.3f}".format(
    #         report.correct / report.nb_test * 100.0
    #     )
    # )
    # print(