def test_bpda_on_nograd_defense(device, def_cls):
    """A BPDA-wrapped non-differentiable defense should admit data gradients.

    First substitutes the identity for the backward pass, then stacks a
    straight-through backward on top of that wrapper; both must allow
    gradient computation without raising.
    """
    inputs = defense_data[def_cls]
    # Backward pass replaced by the identity function.
    wrapped = BPDAWrapper(def_cls(**defense_kwargs[def_cls]),
                          forwardsub=_identity)
    _calc_datagrad_on_defense(wrapped, inputs)
    # Wrap the wrapper again, this time with an explicit backward function.
    rewrapped = BPDAWrapper(wrapped, backward=_straight_through_backward)
    _calc_datagrad_on_defense(rewrapped, inputs)
def test_bpda_on_withgrad_defense(device, def_cls):
    """BPDA backward substitution behaves sensibly on a differentiable defense.

    Using the defense itself as the backward substitute must reproduce the
    true gradient exactly; using the identity as the substitute must not.
    """
    defense = def_cls(**defense_kwargs[def_cls])
    inputs = defense_data[def_cls]

    # Reference gradient: straight through the defense, no BPDA involved.
    reference_grad = _calc_datagrad_on_defense(defense, inputs)

    identity_backward_grad = _calc_datagrad_on_defense(
        BPDAWrapper(defense, forwardsub=_identity), inputs)
    self_backward_grad = _calc_datagrad_on_defense(
        BPDAWrapper(defense, forwardsub=defense), inputs)

    assert not torch_allclose(identity_backward_grad, reference_grad)
    assert torch_allclose(self_backward_grad, reference_grad)
def test_bpda_on_activations(device, func):
    """BPDA backward substitution behaves sensibly on activation functions.

    Self-substitution must reproduce the true gradient; identity
    substitution must not.
    """
    # NOTE(review): centering presumably keeps the data away from regions
    # where the activation's gradient trivially matches the identity —
    # confirm against the fixture definitions.
    data = vecdata.detach().clone()
    data = data - data.mean()

    reference_grad = _calc_datagrad_on_defense(func, data)
    identity_backward_grad = _calc_datagrad_on_defense(
        BPDAWrapper(func, forwardsub=_identity), data)
    self_backward_grad = _calc_datagrad_on_defense(
        BPDAWrapper(func, forwardsub=func), data)

    assert not torch_allclose(identity_backward_grad, reference_grad)
    assert torch_allclose(self_backward_grad, reference_grad)
def test_bpda_nograd_on_multi_input(device, func):
    """Default (straight-through) BPDA backward on a multi-input forward.

    With no backward substitute given, every input of the wrapped forward
    should receive the downstream gradient unchanged — so the gradients
    w.r.t. both inputs equal the gradient w.r.t. the wrapper's output.
    """

    class TwoInputCombiner(nn.Module):
        # The coefficients would matter for a true gradient, but the
        # straight-through backward ignores them entirely.
        def forward(self, x, y):
            return 2.0 * x - 1.0 * y

    class FlattenLinearNet(nn.Module):
        def __init__(self):
            super(FlattenLinearNet, self).__init__()
            self.linear = nn.Linear(1200, 10)  # 3 * 20 * 20 = 1200 features

        def forward(self, x):
            x = x.view(x.shape[0], -1)
            return self.linear(x)

    bpda = BPDAWrapper(forward=TwoInputCombiner())

    with torch.enable_grad():
        x = torch.rand(size=(10, 3, 20, 20), device=device,
                       requires_grad=True)
        y = torch.rand_like(x, requires_grad=True)
        z = bpda(x, y)
        # Detached copy of z used to compute the reference output-gradient.
        z_ref = z.detach().requires_grad_()

    net = nn.Sequential(func, FlattenLinearNet())
    with torch.enable_grad():
        ref_loss = net(z_ref).sum()
        loss = net(z).sum()
        grad_z, = torch.autograd.grad(ref_loss, [z_ref])
        grad_x, grad_y = torch.autograd.grad(loss, [x, y])

    assert torch_allclose(grad_x, grad_z)
    assert torch_allclose(grad_y, grad_z)
def get_madry_et_al_tf_model(dataname, weights_path, device="cuda"):
    """Load a Madry et al. challenge TensorFlow model as a BPDA-wrapped module.

    Imports the challenge code (downloading it via a shell script into
    MODEL_PATH if the import fails), wraps the TF model for PyTorch use,
    and returns a BPDAWrapper whose forward/backward convert between the
    PyTorch NCHW tensor layout and the layout the TF model expects.

    :param dataname: dataset identifier, "mnist" or "cifar"; anything else
        raises ValueError.
    :param weights_path: path to the pretrained TF checkpoint.
    :param device: torch device for the wrapped model (default "cuda").
    :return: a BPDAWrapper module around the TF model.
    """
    if dataname == "mnist":
        try:
            from .mnist_challenge.model import Model
            print("mnist_challenge found and imported")
        except (ImportError, ModuleNotFoundError):
            # Fetch the challenge repository, then retry the import.
            print("mnist_challenge not found, downloading ...")
            os.system("bash download_mnist_challenge.sh {}".format(MODEL_PATH))
            from .mnist_challenge.model import Model
            print("mnist_challenge found and imported")

        def _process_inputs_val(val):
            # (N, 1, 28, 28) tensor -> flat (N, 784) as the MNIST model expects.
            return val.view(val.shape[0], 784)

        def _process_grads_val(val):
            # Inverse of _process_inputs_val: back to image-shaped gradients.
            return val.view(val.shape[0], 1, 28, 28)
    elif dataname == "cifar":
        try:
            from .cifar10_challenge.model import Model
            print("cifar10_challenge found and imported")
        except (ImportError, ModuleNotFoundError):
            print("cifar10_challenge not found, downloading ...")
            os.system(
                "bash download_cifar10_challenge.sh {}".format(MODEL_PATH))
            from .cifar10_challenge.model import Model
            print("cifar10_challenge found and imported")
        from functools import partial
        # The CIFAR-10 model constructor takes a mode flag; freeze it to eval.
        Model = partial(Model, mode="eval")

        def _process_inputs_val(val):
            # NCHW in [0, 1] -> NHWC in [0, 255] for the TF model.
            return 255. * val.permute(0, 2, 3, 1)

        def _process_grads_val(val):
            # Inverse transform (including the 1/255 scale) for gradients.
            return val.permute(0, 3, 1, 2) / 255.
    else:
        raise ValueError(dataname)

    def _wrap_forward(forward):
        # Convert inputs to the TF layout before calling the TF forward.
        def new_forward(inputs_val):
            return forward(_process_inputs_val(inputs_val))
        return new_forward

    def _wrap_backward(backward):
        # NOTE(review): unlike new_forward, this unpacks inputs_val and
        # logits_grad_val with `*` — presumably BPDAWrapper hands the
        # backward a tuple of saved inputs and a tuple of output grads;
        # confirm against BPDAWrapper's backward-callback convention.
        def new_backward(inputs_val, logits_grad_val):
            return _process_grads_val(
                backward(_process_inputs_val(*inputs_val), *logits_grad_val))
        return new_backward

    # Bridge the TF model into PyTorch, then expose it through BPDA with the
    # layout-converting forward/backward defined above.
    ptmodel = TorchWrappedModel(WrappedTfModel(weights_path, Model), device)
    model = BPDAWrapper(forward=_wrap_forward(ptmodel.forward),
                        backward=_wrap_backward(ptmodel.backward))
    return model
def init_advertorch(self, model, device, attack_params, dataset_params):
    """Build an advertorch attack configuration for `model`.

    Composes normalization (and, optionally, a BPDA-wrapped preprocessing
    defense) with the model, then returns a dict describing the requested
    attack. Stores the normalization layer on `self.normalize` as a side
    effect.

    :param model: the base classifier to attack.
    :param device: torch device the attack model is moved to.
    :param attack_params: dict with keys 'attack' ('pgd' or 'cw'), 'bpda',
        'iterations', 'epsilon', and attack-specific keys ('stepsize' and
        'random' for PGD; 'preprocess' when 'bpda' is truthy).
    :param dataset_params: dict with 'mean', 'std', 'num_classes'.
    :return: dict of attack class, hyperparameters, and the attack model.
    :raises ValueError: for an unsupported attack name.
    """
    mean = dataset_params['mean']
    std = dataset_params['std']
    num_classes = dataset_params['num_classes']

    # Fold normalization into the attacked model so the attack itself
    # operates in un-normalized input space.
    self.normalize = NormalizeByChannelMeanStd(mean=mean, std=std)
    basic_model = model

    # PEP 8 (E712): truthiness test instead of `== True`.
    if attack_params['bpda']:
        # Wrap the (possibly non-differentiable) preprocessing with BPDA,
        # using its back_approx as the differentiable backward substitute.
        preprocess = attack_params['preprocess']
        preprocess_bpda_wrapper = BPDAWrapper(
            preprocess, forwardsub=preprocess.back_approx)
        attack_model = nn.Sequential(
            self.normalize, preprocess_bpda_wrapper, basic_model).to(device)
    else:
        attack_model = nn.Sequential(self.normalize, basic_model).to(device)

    attack_name = attack_params['attack'].lower()
    if attack_name == 'pgd':
        return {
            'attack': advertorch.attacks.LinfPGDAttack,
            'iterations': attack_params['iterations'],
            'epsilon': attack_params['epsilon'],
            'stepsize': attack_params['stepsize'],
            'model': attack_model,
            'random': attack_params['random'],
        }
    if attack_name == 'cw':
        return {
            'attack': advertorch.attacks.CarliniWagnerL2Attack,
            'iterations': attack_params['iterations'],
            'epsilon': attack_params['epsilon'],
            'model': attack_model,
            'num_classes': num_classes,
        }
    raise ValueError("Unsupported attack")
from advertorch.defenses import MedianSmoothing2D
from advertorch.defenses import BitSqueezing
from advertorch.defenses import JPEGFilter

# Preprocessing defense pipeline: JPEG compression, bit-depth reduction,
# then median filtering.
bits_squeezing = BitSqueezing(bit_depth=5)
median_filter = MedianSmoothing2D(kernel_size=3)
jpeg_filter = JPEGFilter(10)

defense = nn.Sequential(
    jpeg_filter,
    bits_squeezing,
    median_filter,
)

from advertorch.attacks import LBFGSAttack
from advertorch.bpda import BPDAWrapper

# BPDA: forward through the defense, backward through the identity.
defense_withbpda = BPDAWrapper(defense, forwardsub=lambda x: x)
defended_model = nn.Sequential(defense_withbpda, model)

# BUG FIX: the BPDA adversary must attack the *defended* model; previously
# it attacked the bare `model`, leaving `defended_model` (and hence the
# whole BPDA setup above) unused.
bpda_adversary = LBFGSAttack(
    defended_model, loss_fn=nn.CrossEntropyLoss(reduction="sum"),
    num_classes=10, targeted=False)

bpda_adv = bpda_adversary.perturb(cln_data, true_label)

import matplotlib.pyplot as plt
import numpy as np

# Visualization grid: row 1 shows the clean inputs; the loop continues
# beyond this chunk.
plt.figure(figsize=(10, 8))
for ii in range(batch_size):
    plt.subplot(3, batch_size, ii + 1)
    _imshow(cln_data[ii])
    plt.subplot(3, batch_size, ii + 1 + batch_size)