Exemple #1
0
    def _pgd_whitebox(X, mean, std):
        """Craft PGD adversarial examples for ``X`` and return the ensemble
        embedding and a mutual-information uncertainty score on them.

        Closure over ``model``, ``args``, ``num_ens`` and ``_grad`` from the
        enclosing scope.

        Args:
            X: input batch with pixel values in [0, 1] (pre-normalization).
            mean: per-channel mean used to normalize inputs for the model.
            std: per-channel std used to normalize inputs for the model.

        Returns:
            (embedding, mis): the L2-normalized ensemble-averaged embedding
            and a per-sample mutual-information score.
        """
        # Clean forward pass with the posterior frozen so the target ``y`` is
        # deterministic; the second positional arg presumably toggles the
        # model's output mode (cf. return_both below) — TODO confirm.
        # The reshape pairs consecutive samples (B//2, 2, -1); presumably the
        # loader yields samples in related pairs — verify against the caller.
        freeze(model)
        y = model(X.sub(mean).div(std), True).reshape(X.size(0) // 2, 2, -1)
        unfreeze(model)

        X_pgd = X.clone()
        # Optional random start uniformly inside the L-inf epsilon ball.
        if args.random:
            X_pgd += torch.cuda.FloatTensor(*X_pgd.shape).uniform_(
                -args.epsilon, args.epsilon)

        # Iterated FGSM: step along the gradient sign, then project back into
        # the epsilon ball around X and clamp to the valid pixel range [0, 1].
        for _ in range(args.num_steps):
            grad_ = _grad(X_pgd, y, mean, std)
            X_pgd += args.step_size * grad_.sign()
            eta = torch.clamp(X_pgd - X, -args.epsilon, args.epsilon)
            X_pgd = torch.clamp(X + eta, 0, 1.0)

        # Monte-Carlo ensemble over num_ens stochastic forward passes on the
        # adversarial batch, maintaining running means incrementally.
        mis = 0
        preds = 0
        embedding_b = 0
        for ens in range(num_ens):
            output, output_logits = model(X_pgd.sub(mean).div(std),
                                          return_both=True)
            # Ensemble-averaged embedding.
            embedding_b += output / num_ens
            # Running mean of the per-pass predictive entropy
            # (the "expected entropy" term of mutual information).
            mis = (mis * ens +
                   (-output_logits.softmax(-1) *
                    (output_logits).log_softmax(-1)).sum(1)) / (ens + 1)
            # Running mean of the softmax predictions.
            preds = (preds * ens + output_logits.softmax(-1)) / (ens + 1)

        # L2-normalize the averaged embedding along dim 1.
        norm = torch.norm(embedding_b, 2, 1, True)
        embedding = torch.div(embedding_b, norm)
        # Mutual information = entropy of the mean prediction minus the mean
        # per-pass entropy; with a single pass the correction term is zero.
        mis = (-preds *
               (preds + 1e-8).log()).sum(1) - (0 if num_ens == 1 else mis)
        return embedding, mis
Exemple #2
0
def evaluate(val_loaders, fake_loader, net, criterion, args, log,
             num_mc_samples, num_mc_samples2):
    """Run deterministic and MC-ensemble validation, then score adversarial
    and DeepFake detection (average precision) against each validation set.

    Only the process with ``args.gpu == 0`` prints/logs the AP summaries.
    """
    is_main = args.gpu == 0

    # Deterministic pass: freeze the posterior so one fixed sample is used.
    freeze(net)
    if is_main:
        print("-----------------deterministic-----------------")
    deter_rets = ens_validate(val_loaders, net, criterion, args, log, 1)
    unfreeze(net)

    # Stochastic ensemble pass with num_mc_samples2 posterior samples.
    if is_main:
        print("-----------------ensemble {} times-----------------".format(
            num_mc_samples2))
    rets = ens_validate(val_loaders, net, criterion, args, log,
                        num_mc_samples2)

    # PGD attack; plot_mi reads the per-loader uncertainty files it produced.
    ens_attack(val_loaders, net, criterion, args, log, num_mc_samples,
               min(num_mc_samples, 8))
    if is_main:
        for k in val_loaders:
            ap = plot_mi(args.save_path, 'adv_' + k[0], k[0])
            print_log('{} vs. adversarial: AP {}'.format(k[0], ap), log)

    # DeepFake data pass; results are written under the 'fake' suffix.
    ens_validate(fake_loader, net, criterion, args, log, num_mc_samples,
                 suffix='fake')
    if is_main:
        for k in val_loaders:
            ap = plot_mi(args.save_path, 'fake', k[0])
            print_log('{} vs. DeepFake: AP {}'.format(k[0], ap), log)
Exemple #3
0
def evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net, criterion,
             args, log, num_mc_samples, num_mc_samples2):
    """Full evaluation suite: deterministic vs. ensemble accuracy/ECE plus
    OOD detection (AP) against PGD, BigGAN fakes and transferred FGSM.

    Returns:
        A pair taken from the final ensemble row of ``rets``
        (indices -3 and -4; meanings inferred from the log lines below —
        TODO confirm against ens_validate).
    """
    is_main = args.gpu == 0

    # Single deterministic forward with the posterior frozen.
    freeze(net)
    deter_rets = ens_validate(test_loader, net, criterion, args, log, 1)
    unfreeze(net)

    # Ensemble pass with num_mc_samples2 posterior samples.
    rets = ens_validate(test_loader, net, criterion, args, log,
                        num_mc_samples2)

    # Three parallel summary lines differ only in label and column indices.
    for label, avg, ens_val, det_val in (
            ('TOP1', rets[:, 2].mean(), rets[-1][-3], deter_rets[0][2]),
            ('TOP5', rets[:, 3].mean(), rets[-1][-2], deter_rets[0][3]),
            ('LOS ', rets[:, 1].mean(), rets[-1][-4], deter_rets[0][1])):
        print_log(
            '{} average: {:.4f}, ensemble: {:.4f}, deter: {:.4f}'.format(
                label, avg, ens_val, det_val), log)
    print_log(
        'ECE  ensemble: {:.4f}, deter: {:.4f}'.format(rets[-1][-1],
                                                      deter_rets[-1][-1]), log)
    if is_main:
        plot_ens(args.save_path, rets, deter_rets[0][2])

    # PGD attack on the adversarial loader; AP from the 'advg' files.
    ens_attack(adv_loader, net, criterion, args, log, num_mc_samples)
    if is_main:
        print_log('NAT vs. ADV: AP {}'.format(plot_mi(args.save_path, 'advg')),
                  log)

    # BigGAN-generated fakes.
    ens_validate(fake_loader, net, criterion, args, log, num_mc_samples,
                 suffix='_fake')
    if is_main:
        ap = plot_mi(args.save_path, 'fake')
        print_log('NAT vs. Fake (BigGAN): AP {}'.format(ap), log)

    # FGSM adversarials transferred from a ResNet-152 surrogate.
    ens_validate(adv_loader2, net, criterion, args, log, num_mc_samples,
                 suffix='_adv')
    if is_main:
        ap = plot_mi(args.save_path, 'adv')
        print_log('NAT vs. FGSM-ResNet152: AP {}'.format(ap), log)

    return rets[-1][-3], rets[-1][-4]
Exemple #4
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-process worker: build a Bayesian net from a pretrained model,
    set up (optionally distributed) training, resume from a checkpoint,
    then either evaluate or run the training loop.

    Args:
        gpu: local GPU index this worker is pinned to.
        ngpus_per_node: GPUs per node, used for rank and batch-size math.
        args: parsed command-line namespace (mutated: gpu, rank, batch_size,
            start_epoch, resume may be updated in place).
    """
    global best_acc
    args.gpu = gpu
    assert args.gpu is not None
    print("Use GPU: {} for training".format(args.gpu))

    # Log handle is paired with the gpu rank; presumably print_log uses the
    # rank so only rank 0 writes to stdout — TODO confirm.
    log = open(
        os.path.join(
            args.save_path,
            'log_seed{}{}.txt'.format(args.manualSeed,
                                      '_eval' if args.evaluate else '')), 'w')
    log = (log, args.gpu)

    # Start from a deterministic pretrained backbone, strip dropout, convert
    # it to its Bayesian counterpart and make variational params trainable.
    net = models.__dict__[args.arch](pretrained=True)
    disable_dropout(net)
    net = to_bayesian(net, args.psi_init_range)
    unfreeze(net)

    print_log("Python version : {}".format(sys.version.replace('\n', ' ')),
              log)
    print_log("PyTorch  version : {}".format(torch.__version__), log)
    print_log("CuDNN  version : {}".format(torch.backends.cudnn.version()),
              log)
    print_log(
        "Number of parameters: {}".format(
            sum([p.numel() for p in net.parameters()])), log)
    print_log(str(args), log)

    if args.distributed:
        # Global rank = node rank * gpus-per-node + local gpu index.
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url + ":" +
                                args.dist_port,
                                world_size=args.world_size,
                                rank=args.rank)
        torch.cuda.set_device(args.gpu)
        net.cuda(args.gpu)
        # Per-process batch size: the global batch is split across GPUs.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        net = torch.nn.parallel.DistributedDataParallel(net,
                                                        device_ids=[args.gpu])
    else:
        torch.cuda.set_device(args.gpu)
        net = net.cuda(args.gpu)

    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

    # Split parameters: 'psi' params (variational posterior scales) get their
    # own optimizer; everything else ('mu') uses plain SGD.
    mus, psis = [], []
    for name, param in net.named_parameters():
        if 'psi' in name: psis.append(param)
        else: mus.append(param)
    mu_optimizer = SGD(mus,
                       args.learning_rate,
                       args.momentum,
                       weight_decay=args.decay)

    psi_optimizer = PsiSGD(psis,
                           args.learning_rate,
                           args.momentum,
                           weight_decay=args.decay)

    recorder = RecorderMeter(args.epochs)
    if args.resume:
        # 'auto' means: resume from the default checkpoint in save_path.
        if args.resume == 'auto':
            args.resume = os.path.join(args.save_path, 'checkpoint.pth.tar')
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume,
                                    map_location='cuda:{}'.format(args.gpu))
            recorder = checkpoint['recorder']
            recorder.refresh(args.epochs)
            args.start_epoch = checkpoint['epoch']
            # Non-distributed nets lack the DDP 'module.' prefix, so strip it
            # from checkpoints saved by a DDP-wrapped model.
            net.load_state_dict(
                checkpoint['state_dict'] if args.distributed else {
                    k.replace('module.', ''): v
                    for k, v in checkpoint['state_dict'].items()
                })
            mu_optimizer.load_state_dict(checkpoint['mu_optimizer'])
            psi_optimizer.load_state_dict(checkpoint['psi_optimizer'])
            best_acc = recorder.max_accuracy(False)
            print_log(
                "=> loaded checkpoint '{}' accuracy={} (epoch {})".format(
                    args.resume, best_acc, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume),
                      log)
    else:
        print_log("=> do not use any checkpoint for the model", log)

    cudnn.benchmark = True

    train_loader, ood_train_loader, test_loader, adv_loader, \
        fake_loader, adv_loader2 = load_dataset_ft(args)
    # Dataset size on the psi optimizer — presumably used to scale the KL
    # term of the variational objective; TODO confirm in PsiSGD.
    psi_optimizer.num_data = len(train_loader.dataset)

    if args.evaluate:
        evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net,
                 criterion, args, log, 20, 100)
        return

    start_time = time.time()
    epoch_time = AverageMeter()
    train_los = -1

    for epoch in range(args.start_epoch, args.epochs):
        # Re-seed the distributed samplers so each epoch gets a new shuffle.
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
            ood_train_loader.sampler.set_epoch(epoch)
        cur_lr, cur_slr = adjust_learning_rate(mu_optimizer, psi_optimizer,
                                               epoch, args)

        # ETA estimate from the running average epoch duration.
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f} {:6.4f}]'.format(
                                    time_string(), epoch, args.epochs, need_time, cur_lr, cur_slr) \
                    + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        train_acc, train_los = train(train_loader, ood_train_loader, net,
                                     criterion, mu_optimizer, psi_optimizer,
                                     epoch, args, log)
        # NOTE(review): no validation is run during training (val_acc is
        # always 0), so is_best never fires unless best_acc is negative.
        val_acc, val_los = 0, 0
        recorder.update(epoch, train_los, train_acc, val_acc, val_los)

        is_best = False
        if val_acc > best_acc:
            is_best = True
            best_acc = val_acc

        # Only rank 0 writes the checkpoint (is_best is intentionally not
        # passed; the call hard-codes False).
        if args.gpu == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': net.state_dict(),
                    'recorder': recorder,
                    'mu_optimizer': mu_optimizer.state_dict(),
                    'psi_optimizer': psi_optimizer.state_dict(),
                }, False, args.save_path, 'checkpoint.pth.tar')

        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'log.png'))

    # Final full evaluation after training completes.
    evaluate(test_loader, adv_loader, fake_loader, adv_loader2, net, criterion,
             args, log, 20, 100)

    log[0].close()
Exemple #5
0
    args.cutout = True
    args.distributed = False
    args.batch_size = 32
    args.workers = 4
    train_loader, test_loader = load_dataset(args)

    net = wrn(pretrained=True, depth=28, width=10).cuda()
    disable_dropout(net)

    eval_loss, eval_acc = Bayes_ensemble(test_loader, net,
                                         num_mc_samples=1)
    print('Results of deterministic pre-training, '
          'eval loss {}, eval acc {}'.format(eval_loss, eval_acc))

    bayesian_net = to_bayesian(net)
    unfreeze(bayesian_net)

    mus, psis = [], []
    for name, param in bayesian_net.named_parameters():
        if 'psi' in name: psis.append(param)
        else: mus.append(param)
    mu_optimizer = SGD(mus, lr=0.0008, momentum=0.9,
                       weight_decay=2e-4, nesterov=True)
    psi_optimizer = PsiSGD(psis, lr=0.1, momentum=0.9,
                           weight_decay=2e-4, nesterov=True,
                           num_data=50000)

    for epoch in range(args.epochs):
        bayesian_net.train()
        for i, (input, target) in enumerate(train_loader):
            input = input.cuda(non_blocking=True)