Beispiel #1
0
    def __init__(self, args):
        """Build the data loader, defense pipeline, classifier and attack.

        Args:
            args: Options object. This method reads ``args.device`` (the
                CUDA device index appended to ``'cuda:'``) and
                ``args.resume_oracle`` (the checkpoint handed to
                ``torch.load`` for the classifier weights).
        """
        self.args = args

        # setup data_loader instances
        # NOTE(review): 100 is passed positionally, presumably the batch
        # size -- confirm against get_mnist_test_loader's signature.
        self.data_loader = get_mnist_test_loader(100, shuffle=False)

        # setup device; str() accepts an integer index as well as its
        # string form, which plain concatenation would reject.
        self.device = torch.device('cuda:' + str(args.device))

        # build defense architecture: JPEG filtering, then bit-depth
        # squeezing, then median smoothing, applied in that order.
        bits_squeezing = BitSqueezing(bit_depth=5)
        median_filter = MedianSmoothing2D(kernel_size=3)
        jpeg_filter = JPEGFilter(10)

        self.defense = nn.Sequential(
            jpeg_filter,
            bits_squeezing,
            median_filter,
        )

        # build classifier architecture
        self.oracle = LeNet().to(self.device)
        # map_location restores the checkpoint directly onto the selected
        # device instead of the device it happened to be saved from.
        state_dict = torch.load(args.resume_oracle,
                                map_location=self.device)
        self.oracle.load_state_dict(state_dict)
        self.oracle.eval()

        # The attack receives the classifier and the defense as separate
        # components, together with the device and its hyper-parameters.
        self.adversary = BPDAattack(self.oracle, self.defense, self.device,
                                    epsilon=0.3,
                                    learning_rate=0.5,
                                    max_iterations=100)
Beispiel #2
0
def test_median_filter():
    """Compare MedianSmoothing2D with SciPy's median filter on interior pixels.

    Both filters are applied to the module-level ``data`` tensor with the
    same square kernel. Only the interior region (a border of
    ``kernel_size // 2`` pixels is cropped from the last two axes) is
    compared, because the two implementations disagree on the boundaries.
    """
    # XXX: doesn't pass when kernel_size is even
    # XXX: when kernel_size is odd, pixels on the boundaries are different
    kernel_size = 3
    padding = kernel_size // 2

    # A window of size 1 along the first two axes restricts the median to
    # the last two axes. ``ndimage.median_filter`` is the supported entry
    # point; the ``ndimage.filters`` namespace is deprecated in SciPy.
    rval_scipy = ndimage.median_filter(
        data.detach().numpy(),
        size=(1, 1, kernel_size, kernel_size),
    )
    rval = MedianSmoothing2D(kernel_size=kernel_size)(data).detach().numpy()

    # Crop ``padding`` pixels from each side of the last two axes.
    crop = slice(padding, -padding)
    interior = (slice(None), slice(None), crop, crop)
    assert np.allclose(rval_scipy[interior], rval[interior])
Beispiel #3
0
    torch.load(os.path.join( filename),map_location=torch.device('cpu')))
# Move the model to the target device and switch it to evaluation mode.
model.to(device)
model.eval()

batch_size = 100
loader = get_mnist_test_loader(batch_size=batch_size)

# Only the first batch is used. next(iter(...)) fetches it directly and
# raises StopIteration right here if the loader is empty, instead of
# leaving the names unbound for a later NameError.
cln_data, true_label = next(iter(loader))
cln_data, true_label = cln_data.to(device), true_label.to(device)

from advertorch.defenses import MedianSmoothing2D
from advertorch.defenses import BitSqueezing
from advertorch.defenses import JPEGFilter

# Individual input-transformation stages.
jpeg_filter = JPEGFilter(10)
bits_squeezing = BitSqueezing(bit_depth=5)
median_filter = MedianSmoothing2D(kernel_size=3)

# Chain the stages: JPEG filtering first, then bit squeezing, then median
# smoothing.
defense = nn.Sequential(jpeg_filter, bits_squeezing, median_filter)
from advertorch.attacks import LocalSearchAttack
from advertorch.bpda import BPDAWrapper
#defense_withbpda = BPDAWrapper(defense, forwardsub=lambda x: x)
#defended_model = nn.Sequential(defense_withbpda, model)
from advertorch.attacks import SinglePixelAttack
# Despite the variable name, this is a LocalSearchAttack built directly on
# `model`, using a summed cross-entropy loss.
summed_cross_entropy = nn.CrossEntropyLoss(reduction="sum")
bpda_adversary = LocalSearchAttack(model, loss_fn=summed_cross_entropy)
Beispiel #4
0
                              clip_max=1.0,
                              targeted=False)
else:
    #adversary = L2PGDAttack(net, loss_fn=csl, eps=args.eps, nb_iter=10, eps_iter=10.0,rand_init=False, clip_min=0.0, clip_max=255., targeted=False)
    #adversary = PGDAttack(net, loss_fn=csl, eps=args.eps, nb_iter=10, eps_iter=1., rand_init=False, clip_min=0.0,clip_max=255., targeted=False)
    # Attack used in this branch: an L2 basic iterative attack on `net`
    # with the `csl` loss, perturbation budget taken from args.eps, and
    # 10 iterations with a per-iteration step of 10.0.
    # NOTE(review): the clip range [0, 255] suggests inputs are in raw
    # pixel scale rather than [0, 1] -- confirm against the data pipeline.
    adversary = L2BasicIterativeAttack(net,
                                       loss_fn=csl,
                                       eps=args.eps,
                                       nb_iter=10,
                                       eps_iter=10.0,
                                       clip_min=0.0,
                                       clip_max=255.,
                                       targeted=False)
    #adversary =MomentumIterativeAttack(net, loss_fn=csl, eps=args.eps, nb_iter=10, eps_iter=1, clip_min=0.0, clip_max=255., targeted=False)
    # adversary = JSMA(net, num_classes = 24, clip_min=0.0, clip_max=255.,theta=1.0, gamma=1.0,loss_fn=nn.CrossEntropyLoss())
# Input-transformation defenses, each constructed with its default arguments.
# (Alternatives previously tried for the second slot:
#  AverageSmoothing2D(3, 3) and GaussianSmoothing2D(0.1, 3, 3).)
defensed = MedianSmoothing2D()
defensed2 = JPEGFilter()

# Switch the network to evaluation mode.
# (net2 .. net5 were evaluated here in earlier variants of this script.)
net.eval()

# Counters initialised to zero; presumably accumulated by the evaluation
# code that follows.
correct, total = 0, 0
if args.attack_for_twoc:
    from dataload_test import load_test_list, load_test_list2, get_test, get_test2
    test_num = load_test_list()
    test_num2 = load_test_list2()