def test_bpda_nograd_on_multi_input(device, func):
    """BPDA straight-through backward on a multi-input wrapper.

    When ``BPDAWrapper`` is given only ``forward`` (no differentiable
    substitute), the backward pass routes the output gradient straight
    through to *every* input, so the gradient w.r.t. each input must
    equal the gradient w.r.t. the wrapper's output.
    """

    class TwoInputFunc(nn.Module):
        # A non-trivial combination of two inputs; its true Jacobian
        # (2 for x, -1 for y) is deliberately NOT what BPDA should return.
        def forward(self, x, y):
            return 2.0 * x - 1.0 * y

    class SmallClassifier(nn.Module):
        # Flatten + linear head used to produce a scalar loss downstream.
        def __init__(self):
            super(SmallClassifier, self).__init__()
            self.linear = nn.Linear(1200, 10)

        def forward(self, x):
            flat = x.view(x.shape[0], -1)
            return self.linear(flat)

    bpda = BPDAWrapper(forward=TwoInputFunc())
    with torch.enable_grad():
        x = torch.rand(
            size=(10, 3, 20, 20), device=device, requires_grad=True)
        y = torch.rand_like(x, requires_grad=True)
        z = bpda(x, y)
    z_detached = z.detach().requires_grad_()
    net = nn.Sequential(func, SmallClassifier())
    with torch.enable_grad():
        loss_detached = net(z_detached).sum()
        loss = net(z).sum()
    grad_z, = torch.autograd.grad(loss_detached, [z_detached])
    grad_x, grad_y = torch.autograd.grad(loss, [x, y])
    # Straight-through: both input gradients mirror the output gradient.
    assert torch_allclose(grad_x, grad_z)
    assert torch_allclose(grad_y, grad_z)
def _run_one_assert_val(ctxmgr, model, assert_inside, assert_outside):
    """Exercise a model-state context manager and check invariants.

    ``assert_inside`` must hold inside the ``with`` block and
    ``assert_outside`` before/after it, while the model's output on
    ``vecdata`` stays unchanged throughout.
    """
    reference = model(vecdata)
    assert_outside(model)
    with ctxmgr(model):
        assert_inside(model)
        # Entering the context must not perturb the forward output.
        assert torch_allclose(reference, model(vecdata))
    assert_outside(model)
    # Nor must exiting it.
    assert torch_allclose(reference, model(vecdata))
def _run_one_assert_consistent(ctxmgr, model, get_state_fn, assert_inside):
    """Check that a context manager restores model state on exit.

    Snapshots the model's state before entering ``ctxmgr``, verifies
    ``assert_inside`` holds inside the block, and checks afterwards that
    a fresh snapshot is a distinct object but equal in value, and that
    the forward output on ``vecdata`` is unchanged.
    """
    # BUG FIX: the initial snapshot previously read ``mix_model`` (a
    # module-level object) instead of the ``model`` argument, so the
    # before/after state comparison was wrong for any other model.
    dct = get_state_fn(model)
    output = model(vecdata)
    with ctxmgr(model):
        assert_inside(model)
        assert torch_allclose(output, model(vecdata))
    newdct = get_state_fn(model)
    # A new snapshot object must have been produced ...
    assert dct is not newdct
    # ... whose contents match the pre-context state exactly.
    assert dct == newdct
    assert torch_allclose(output, model(vecdata))
def test_bpda_on_withgrad_defense(device, def_cls):
    """BPDA backward substitutes on an already-differentiable defense.

    Using the defense itself as the backward substitute must reproduce
    the true data gradient; an identity backward must not.
    """
    defense = def_cls(**defense_kwargs[def_cls])
    true_grad = _calc_datagrad_on_defense(defense, defense_data[def_cls])

    identity_bwd = BPDAWrapper(defense, forwardsub=_identity)
    grad_identity_bwd = _calc_datagrad_on_defense(
        identity_bwd, defense_data[def_cls])

    self_bwd = BPDAWrapper(defense, forwardsub=defense)
    grad_self_bwd = _calc_datagrad_on_defense(
        self_bwd, defense_data[def_cls])

    # Identity backward diverges from the true gradient ...
    assert not torch_allclose(grad_identity_bwd, true_grad)
    # ... while self-backward matches it exactly.
    assert torch_allclose(grad_self_bwd, true_grad)
def test_bpda_on_activations(device, func):
    """BPDA backward substitutes on a differentiable activation.

    Mirrors the defense variant: an identity backward must differ from
    the true gradient, while using the activation itself as the backward
    substitute must reproduce it.
    """
    # Center the data so activations like ReLU see both signs.
    sample = vecdata.detach().clone()
    sample = sample - sample.mean()

    true_grad = _calc_datagrad_on_defense(func, sample)

    identity_bwd = BPDAWrapper(func, forwardsub=_identity)
    grad_identity_bwd = _calc_datagrad_on_defense(identity_bwd, sample)

    self_bwd = BPDAWrapper(func, forwardsub=func)
    grad_self_bwd = _calc_datagrad_on_defense(self_bwd, sample)

    assert not torch_allclose(grad_identity_bwd, true_grad)
    assert torch_allclose(grad_self_bwd, true_grad)
def test_cifar10_normalize():
    """``NormalizeByChannelMeanStd`` matches per-image ``F.normalize``
    on a CIFAR10-shaped batch (16 x 3 x 32 x 32)."""
    batch = torch.rand((16, 3, 32, 32))
    normalize = NormalizeByChannelMeanStd(CIFAR10_MEAN, CIFAR10_STD)
    # Reference: normalize each image independently, then re-stack.
    expected = torch.stack([
        F.normalize(img, CIFAR10_MEAN, CIFAR10_STD)
        for img in batch.clone()
    ])
    assert torch_allclose(expected, normalize(batch))
def test_mnist_normalize():
    """``NormalizeByChannelMeanStd`` matches per-image ``F.normalize``
    on an MNIST-shaped batch (16 x 1 x 28 x 28)."""
    batch = torch.rand((16, 1, 28, 28))
    normalize = NormalizeByChannelMeanStd(MNIST_MEAN, MNIST_STD)
    # Reference: normalize each image independently, then re-stack.
    expected = torch.stack([
        F.normalize(img, MNIST_MEAN, MNIST_STD)
        for img in batch.clone()
    ])
    assert torch_allclose(expected, normalize(batch))
def test_grad_through_normalize():
    """Gradients flow through ``NormalizeByChannelMeanStd``.

    With mean 0 and std 1 the layer is the identity, so for
    ``loss = sum(x**2)`` the gradient must be exactly ``2 * x``.
    """
    batch = torch.rand((2, 1, 28, 28))
    batch.requires_grad_()
    zero_mean = torch.tensor((0., ))
    unit_std = torch.tensor((1., ))
    normalize = NormalizeByChannelMeanStd(zero_mean, unit_std)
    loss = (normalize(batch) ** 2).sum()
    loss.backward()
    assert torch_allclose(2 * batch, batch.grad)
def _run_batch_consistent(data, label, model, att_cls, idx):
    """Attacking a full batch and attacking one example must agree.

    The perturbation of example ``idx`` computed inside a batch attack
    has to match the perturbation from attacking that example alone.
    """
    if att_cls in feature_attacks:
        # Feature attacks take a "guide" batch instead of labels; use a
        # shuffled copy of the data as the guide.
        guide = data.detach().clone()[torch.randperm(len(data))]
        data, guide = data.to(cpu), guide.to(cpu)
        label_or_guide = guide
    else:
        label_or_guide = label
    model.to(cpu)
    data, label_or_guide = data.to(cpu), label_or_guide.to(cpu)
    adversary = att_cls(model, **attack_kwargs[att_cls])
    from_batch = adversary.perturb(data, label_or_guide)[idx:idx + 1]
    alone = adversary.perturb(
        data[idx:idx + 1], label_or_guide[idx:idx + 1])
    assert torch_allclose(from_batch, alone)
def _run_vec_eps_consistent(data, label, model, att_cls):
    """Scalar vs. broadcast hyperparameters must yield identical attacks.

    Runs the attack once with scalar ``eps``/``eps_iter``/``clip_*``,
    then again (under the same RNG seed) after broadcasting them to
    per-example vectors and per-pixel tensors of ones, and asserts the
    perturbed outputs are identical.
    """
    if att_cls in feature_attacks:
        # Feature attacks take a "guide" batch instead of labels.
        guide = data.detach().clone()[torch.randperm(len(data))]
        data, guide = data.to(cpu), guide.to(cpu)
        label_or_guide = guide
    else:
        label_or_guide = label
    model.to(cpu)
    data, label_or_guide = data.to(cpu), label_or_guide.to(cpu)
    adversary = att_cls(model, **attack_kwargs[att_cls])

    torch.manual_seed(0)
    scalar_result = adversary.perturb(data, label_or_guide)

    # Broadcast scalar hyperparameters: eps (and eps_iter, if present)
    # to a per-example vector, clip bounds to per-pixel tensors.
    ones_per_example = data.new_ones(size=(len(data), ))
    ones_per_pixel = torch.ones_like(data)
    adversary.eps = adversary.eps * ones_per_example
    if hasattr(adversary, "eps_iter"):
        adversary.eps_iter = adversary.eps_iter * ones_per_example
    adversary.clip_min = adversary.clip_min * ones_per_pixel
    adversary.clip_max = adversary.clip_max * ones_per_pixel

    # Re-seed so any stochastic initialization is reproduced exactly.
    torch.manual_seed(0)
    vector_result = adversary.perturb(data, label_or_guide)
    assert torch_allclose(scalar_result, vector_result)
def test_binary_filter():
    """``BinaryFilter`` thresholds the input at 0.5."""
    filtered = BinaryFilter()(data)
    assert torch_allclose(filtered, data > 0.5)