def AOA_Attack(model, x_adv, label, target):
    """Run a fixed-length iterative gradient attack and return the result.

    The loop has two phases:

    * iterations ``0..10`` minimise ``LogitsAdvLoss(0.0)(pred, target)``;
    * iterations ``11..149`` minimise
      ``-log(sum(W * |R_ori - R_sec|)) - 0.5 * log(sum(|x_adv - x_ori|))``,
      where ``R_ori`` / ``R_sec`` are the ``LRP_scores`` of the clean and the
      current input, and ``W`` shrinks linearly with the displacement.

    After every step the gradient is normalised per sample, a fixed-size
    step is taken against it, and the result is passed through
    ``ClipPointsLinf(budget=0.08)`` together with the clean input.

    Args:
        model: callable; ``model(x)`` must return a 3-tuple whose first
            element is the prediction fed to ``LogitsAdvLoss``.
        x_adv: starting input, a 3-D tensor (the gradient norm is reduced
            over dims 1 and 2). It is not modified; the attack works on a
            detached copy.
            # NOTE(review): presumably (batch, 3, num_points) -- confirm.
        label: forwarded twice to ``LRP_scores``.
            # NOTE(review): presumably the ground-truth class -- confirm.
        target: forwarded to ``LogitsAdvLoss`` during the warm-up phase.

    Returns:
        The perturbed input after 150 iterations, as a fresh leaf tensor
        with ``requires_grad=True``.
    """
    budget = 0.08          # bound handed to ClipPointsLinf
    step_size = 0.066      # length of each normalised gradient step
    num_iter = 150
    warmup_last_iter = 10  # iterations with index <= this use the logits loss
    # 12.5 * 0.08 == 1, so W below goes linearly from 1 (no displacement)
    # to 0 (displacement == budget).
    weight_scale = 12.5

    clip_func = ClipPointsLinf(budget=budget)
    logits_loss = LogitsAdvLoss(0.0)

    x_ori = x_adv.clone().detach()
    # Work on our own leaf so that ``.grad`` is always available and never
    # mixed with gradients already stored on the caller's tensor.
    x_adv = x_ori.clone().requires_grad_(True)
    R_ori = LRP_scores(model, x_ori, label, label).detach()

    for iter_k in range(num_iter):
        pred, _, _ = model(x_adv)
        # NOTE(review): R_sec is only consumed after the warm-up; it is still
        # evaluated every iteration in case LRP_scores has side effects.
        R_sec = LRP_scores(model, x_adv, label, label)

        if iter_k > warmup_last_iter:
            delta_abs = (x_adv - x_ori).abs()
            W = weight_scale * (budget - delta_abs)
            loss = (-torch.log((W * (R_ori - R_sec).abs()).sum())
                    - 0.5 * torch.log(delta_abs.sum()))
        else:
            loss = logits_loss(pred, target).mean()

        # NOTE: this also accumulates gradients on the model parameters.
        loss.backward()

        gt = x_adv.grad.detach()
        # Zero out NaN entries. A multiplicative mask cannot do this because
        # NaN * 0 is still NaN, so select explicitly instead.
        gt = torch.where(torch.isnan(gt), torch.zeros_like(gt), gt)

        # Per-sample L2 normalisation; the epsilon guards an all-zero gradient.
        norm = torch.sum(gt ** 2, dim=[1, 2]) ** 0.5
        gt = gt / (norm[:, None, None] + 1e-12)

        x_adv = clip_func(x_adv - step_size * gt, x_ori)
        # Cut the graph and start the next iteration from a new leaf.
        x_adv = x_adv.data.requires_grad_(True)

    print("iter time is%d" % num_iter)
    return x_adv
print('Loading weight {}'.format(BEST_WEIGHTS[args.model])) try: model.load_state_dict(state_dict) except RuntimeError: # eliminate 'module.' in keys state_dict = {k[7:]: v for k, v in state_dict.items()} model.load_state_dict(state_dict) # distributed mode on multiple GPUs! # much faster than nn.DataParallel model = DistributedDataParallel( model.cuda(), device_ids=[args.local_rank]) # setup attack settings if args.adv_func == 'logits': adv_func = LogitsAdvLoss(kappa=args.kappa) else: adv_func = CrossEntropyAdvLoss() dist_func = L2Dist() # hyper-parameters from their official tensorflow code attacker = CWPerturb(model, adv_func, dist_func, attack_lr=args.attack_lr, init_weight=10., max_weight=80., binary_step=args.binary_step, num_iter=args.num_iter) # attack test_set = ModelNet40Attack(args.data_root, num_points=args.num_points, normalize=True) test_sampler = DistributedSampler(test_set, shuffle=False) test_loader = DataLoader(test_set, batch_size=args.batch_size,