Example #1
def _pgd_whitebox(X, y, mean, std):
    # White-box L-infinity PGD attack, followed by evaluation of the
    # adversarial batch under the sampled Bayesian ensemble. Relies on the
    # enclosing scope for args, model, criterion, curvature, inv_factors,
    # posterior_mean and num_ens.
    X_pgd = X.clone()
    if args.random:
        # Random start: uniform noise inside the epsilon ball.
        X_pgd += torch.empty_like(X_pgd).uniform_(-args.epsilon, args.epsilon)

    for _ in range(args.num_steps):
        grad_ = _grad(X_pgd, y, mean, std)
        # Ascend along the gradient sign, then project back into the
        # epsilon ball around X and the valid pixel range [0, 1].
        X_pgd += args.step_size * grad_.sign()
        eta = torch.clamp(X_pgd - X, -args.epsilon, args.epsilon)
        X_pgd = torch.clamp(X + eta, 0, 1.0)

    # Running means of per-member entropies and softmax predictions.
    mis = 0
    preds = 0
    for ens in range(num_ens):
        curvature.sample_and_replace_weights(model, inv_factors, "diag")
        output = model(X_pgd.sub(mean).div(std))
        model.load_state_dict(posterior_mean)  # restore the posterior mean
        mis = (mis * ens + (-output.softmax(-1) *
                            output.log_softmax(-1)).sum(1)) / (ens + 1)
        preds = (preds * ens + output.softmax(-1)) / (ens + 1)

    loss = criterion((preds + 1e-8).log(), y)
    prec1, prec5 = accuracy(preds, y, topk=(1, 5))
    # Mutual information: predictive entropy minus expected member entropy.
    mis = (-preds *
           (preds + 1e-8).log()).sum(1) - (0 if num_ens == 1 else mis)
    return loss, prec1, prec5, mis
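
For orientation, the attack loop above is standard L-infinity PGD; only the gradient comes from the ensemble. A minimal self-contained sketch against a single deterministic model (the names model, epsilon, step_size and num_steps here are placeholders, not the args used above) could look like:

import torch
import torch.nn.functional as F

def pgd_linf(model, X, y, epsilon=8 / 255, step_size=2 / 255, num_steps=10):
    # Random start inside the epsilon ball, clipped to valid pixels.
    X_adv = (X + torch.empty_like(X).uniform_(-epsilon, epsilon)).clamp(0, 1)
    for _ in range(num_steps):
        X_adv = X_adv.detach().requires_grad_(True)
        loss = F.cross_entropy(model(X_adv), y)
        grad = torch.autograd.grad(loss, X_adv)[0]
        with torch.no_grad():
            # Ascend along the gradient sign, then project back into the
            # epsilon ball around X and the image range [0, 1].
            X_adv = X_adv + step_size * grad.sign()
            X_adv = (X + (X_adv - X).clamp(-epsilon, epsilon)).clamp(0, 1)
    return X_adv.detach()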
Example #2
def _grad(X, y, mean, std):
    # Expected input gradient under the weight posterior: each sampled
    # member's gradient is weighted by the probability it assigns to the
    # true class, normalized across members.
    probs = torch.zeros(num_ens, X.shape[0]).cuda(args.gpu)
    grads = torch.zeros(num_ens, *X.shape).cuda(args.gpu)
    for j in range(num_ens):
        with torch.enable_grad():
            X.requires_grad_()
            curvature.sample_and_replace_weights(model, inv_factors,
                                                 "diag")
            output = model(X.sub(mean).div(std))
            model.load_state_dict(posterior_mean)  # restore posterior mean
            loss = torch.nn.functional.cross_entropy(output, y,
                                                     reduction='none')
            grad_ = torch.autograd.grad([loss], [X],
                                        grad_outputs=torch.ones_like(loss),
                                        retain_graph=False)[0].detach()
        grads[j] = grad_
        # Probability of the true class under this sampled member.
        probs[j] = torch.gather(output.detach().softmax(-1), 1,
                                y[:, None]).squeeze()
    probs /= probs.sum(0)
    # Convex combination of member gradients, broadcasting the
    # (num_ens, batch) weights over the channel and spatial axes.
    grad_ = (grads * probs[:, :, None, None, None]).sum(0)
    return grad_
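
The last two lines above reduce to a convex combination of the per-member gradients, with weights given by each member's confidence in the true class. A toy sketch of just that reduction, on random tensors with arbitrarily chosen shapes:

import torch

num_ens, batch, C, H, W = 3, 4, 3, 8, 8
grads = torch.randn(num_ens, batch, C, H, W)  # one input gradient per member
probs = torch.rand(num_ens, batch)            # p(true class | member)

probs = probs / probs.sum(0)  # normalize the weights over members
# Broadcast the (num_ens, batch) weights over the C, H and W axes.
weighted = (grads * probs[:, :, None, None, None]).sum(0)
assert weighted.shape == (batch, C, H, W)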
Example #3
def ens_validate(val_loader,
                 model,
                 criterion,
                 inv_factors,
                 args,
                 log,
                 num_ens=20,
                 suffix=''):
    model.eval()
    if args.dropout_rate > 0.:
        # Keep dropout layers stochastic (MC dropout) during evaluation.
        for m in model.modules():
            if m.__class__.__name__.startswith('Dropout'): m.train()

    # Snapshot the current weights so each posterior sample can be undone.
    posterior_mean = copy.deepcopy(model.state_dict())

    ece_func = _ECELoss().cuda(args.gpu)
    with torch.no_grad():
        targets = []
        # Per-batch running means of member entropies and softmax outputs.
        mis = [0 for _ in range(len(val_loader))]
        preds = [0 for _ in range(len(val_loader))]
        # One row of metrics per ensemble size; the last column holds ECE.
        rets = torch.zeros(num_ens, 9).cuda(args.gpu)

        for ens in range(num_ens):
            # Draw one set of weights from the diagonal posterior.
            curvature.sample_and_replace_weights(model, inv_factors, "diag")
            for i, (input, target) in enumerate(val_loader):
                input = input.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)
                if ens == 0: targets.append(target)

                output = model(input)

                one_loss = criterion(output, target)
                one_prec1, one_prec5 = accuracy(output, target, topk=(1, 5))

                # Running means over the ensemble members seen so far.
                mis[i] = (mis[i] * ens +
                          (-output.softmax(-1) *
                           output.log_softmax(-1)).sum(1)) / (ens + 1)
                preds[i] = (preds[i] * ens + output.softmax(-1)) / (ens + 1)

                # Metrics of the averaged (ensemble) prediction so far.
                loss = criterion((preds[i] + 1e-8).log(), target)
                prec1, prec5 = accuracy(preds[i], target, topk=(1, 5))

                rets[ens, 0] += ens * target.size(0)
                rets[ens, 1] += one_loss.item() * target.size(0)
                rets[ens, 2] += one_prec1.item() * target.size(0)
                rets[ens, 3] += one_prec5.item() * target.size(0)
                rets[ens, 5] += loss.item() * target.size(0)
                rets[ens, 6] += prec1.item() * target.size(0)
                rets[ens, 7] += prec5.item() * target.size(0)

            model.load_state_dict(posterior_mean)  # restore the posterior mean

        preds = torch.cat(preds, 0)

        # Aggregate statistics over the whole validation set.
        confidences, predictions = torch.max(preds, 1)
        targets = torch.cat(targets, 0)
        # Mutual information: predictive entropy minus expected member entropy.
        mis = (-preds * (preds + 1e-8).log()).sum(1) - (
            0 if num_ens == 1 else torch.cat(mis, 0))
        rets /= targets.size(0)

        rets = rets.data.cpu().numpy()
        if suffix == '':
            ens_ece = ece_func(
                confidences, predictions, targets,
                os.path.join(args.save_path, 'ens_cal{}.pdf'.format(suffix)))
            rets[-1, -1] = ens_ece

    if args.gpu == 0:
        np.save(os.path.join(args.save_path, 'mis{}.npy'.format(suffix)),
                mis.data.cpu().numpy())
    return rets
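
The mis bookkeeping shared by all three snippets is the usual mutual-information (BALD-style) decomposition: epistemic uncertainty is the entropy of the averaged prediction minus the average entropy of the individual members. A standalone sketch, assuming a hypothetical tensor member_probs that stacks the softmax outputs of all members:

import torch

def mutual_information(member_probs, eps=1e-8):
    # member_probs: (num_ens, batch, num_classes) softmax outputs.
    mean_probs = member_probs.mean(0)
    # Entropy of the averaged prediction (total uncertainty).
    total = -(mean_probs * (mean_probs + eps).log()).sum(-1)
    # Expected entropy of the individual members (aleatoric part).
    expected = -(member_probs * (member_probs + eps).log()).sum(-1).mean(0)
    return total - expected  # epistemic part, shape (batch,)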