def test_samples_fashionMNIST():
    """Check that ``samples`` returns a well-formed fashionMNIST batch.

    Expects 5 float32 images of shape (28, 28) and one integer label in
    [0, 10) for every image in the batch.
    """
    images, labels = samples(dataset='fashionMNIST', batchsize=5)
    # Validate every label in the batch, not only the first one.
    assert len(labels) == 5
    assert all(isinstance(label, np.integer) for label in labels)
    assert all(0 <= label < 10 for label in labels)
    assert images.shape == (5, 28, 28)
    assert images.dtype == np.float32
def test_samples_imagenet():
    """Check that ``samples`` returns a well-formed ImageNet batch.

    Expects 5 float32 images of shape (224, 224, 3) (channels last, the
    default here) and one integer label in [0, 1000) per image.
    """
    images, labels = samples(dataset='imagenet', batchsize=5)
    # Validate every label in the batch, not only the first one.
    assert len(labels) == 5
    assert all(isinstance(label, np.integer) for label in labels)
    assert all(0 <= label < 1000 for label in labels)
    assert images.shape == (5, 224, 224, 3)
    assert images.dtype == np.float32
def test_samples_cifar100():
    """Check that ``samples`` returns a well-formed CIFAR-100 batch.

    Expects 5 float32 images of shape (32, 32, 3) and one integer label in
    [0, 100) per image.
    """
    images, labels = samples(dataset='cifar100', batchsize=5)
    # Validate every label in the batch, not only the first one.
    assert len(labels) == 5
    assert all(isinstance(label, np.integer) for label in labels)
    assert all(0 <= label < 100 for label in labels)
    assert images.shape == (5, 32, 32, 3)
    assert images.dtype == np.float32
def test_samples_imagenet_channels_first():
    """Check that ``data_format='channels_first'`` moves the channel axis.

    Expects 5 float32 images of shape (3, 224, 224) and one integer label
    in [0, 1000) per image.
    """
    images, labels = samples(dataset='imagenet', batchsize=5, data_format='channels_first')
    # Validate every label in the batch, not only the first one.
    assert len(labels) == 5
    assert all(isinstance(label, np.integer) for label in labels)
    assert all(0 <= label < 1000 for label in labels)
    assert images.shape == (5, 3, 224, 224)
    assert images.dtype == np.float32
def test_samples_mnist():
    """Check that ``samples`` returns a well-formed MNIST batch.

    Expects 5 float32 images of shape (28, 28), one integer label in
    [0, 10) per image, and pixel values on the default 0..255 scale
    (near-black minimum, near-white maximum).
    """
    images, labels = samples(dataset="mnist", batchsize=5)
    # Validate every label in the batch, not only the first one.
    assert len(labels) == 5
    assert all(isinstance(label, np.integer) for label in labels)
    assert all(0 <= label < 10 for label in labels)
    assert images.shape == (5, 28, 28)
    assert images.dtype == np.float32
    # Without explicit bounds the pixels should span (almost) the full 0..255 range.
    assert 0 <= images.min() <= 5
    assert 250 < images.max() <= 255
def test_samples_imagenet_bounds():
    """Check that ``bounds=(0, 1)`` rescales ImageNet pixels into [0, 1].

    Also verifies the batch shape, dtype and that every label lies in
    [0, 1000).
    """
    images, labels = samples(dataset="imagenet", batchsize=5, bounds=(0, 1))
    # Validate every label in the batch, not only the first one.
    assert len(labels) == 5
    assert all(isinstance(label, np.integer) for label in labels)
    assert all(0 <= label < 1000 for label in labels)
    assert images.shape == (5, 224, 224, 3)
    assert images.dtype == np.float32
    assert np.all(images >= 0)
    assert np.all(images <= 1)
def test_samples_mnist_bounds():
    """Check that ``bounds=(-1, 1)`` rescales MNIST pixels into [-1, 1].

    Besides staying inside the bounds, the pixels must actually use both
    signs, which shows the range was shifted rather than only scaled.
    """
    images, labels = samples(dataset="mnist", batchsize=5, bounds=(-1, 1))
    # Validate every label in the batch, not only the first one.
    assert len(labels) == 5
    assert all(isinstance(label, np.integer) for label in labels)
    assert all(0 <= label < 10 for label in labels)
    assert images.shape == (5, 28, 28)
    assert images.dtype == np.float32
    assert np.all(images >= -1)
    assert np.all(images <= 1)
    # Both negative and positive values must occur after rescaling.
    assert images.min() < 0
    assert images.max() > 0
def create() -> PyTorchModel:
    """Build the MNIST CNN, load its bundled weights and wrap it as a PyTorchModel.

    The weights are read from ``mnist_cnn.pth`` in the same directory as
    this file.

    Returns:
        A ``PyTorchModel`` around the network in eval mode, declared with
        input bounds ``(0, 1)`` and preprocessing ``mean=0.1307``,
        ``std=0.3081``.
    """
    model = nn.Sequential(
        nn.Conv2d(1, 32, 3),
        nn.ReLU(),
        nn.Conv2d(32, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Dropout2d(0.25),
        nn.Flatten(),  # type: ignore
        nn.Linear(9216, 128),
        nn.ReLU(),
        # The input here is the flat (batch, 128) output of the Linear layer,
        # so element-wise Dropout is the right layer (Dropout2d expects
        # (N, C, H, W) and warns on 2-D input). Dropout layers have no
        # parameters, so the state_dict layout is unchanged.
        nn.Dropout(0.5),
        nn.Linear(128, 10),
    )
    path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mnist_cnn.pth")
    # map_location="cpu" lets a checkpoint saved from a GPU load on CPU-only
    # hosts; load_state_dict copies the tensors into the CPU model either way.
    model.load_state_dict(torch.load(path, map_location="cpu"))  # type: ignore
    model.eval()
    preprocessing = dict(mean=0.1307, std=0.3081)
    fmodel = PyTorchModel(model, bounds=(0, 1), preprocessing=preprocessing)
    return fmodel


if __name__ == "__main__":
    # test the model
    fmodel = create()
    images, labels = samples(fmodel, dataset="mnist", batchsize=20)
    print(accuracy(fmodel, images, labels))
print("Load Config")
if args.config_path is not None:
    if not os.path.exists(args.config_path):
        raise ValueError("{} doesn't exist.".format(args.config_path))
    # Read through a context manager so the file handle is closed promptly
    # instead of being left for the garbage collector.
    with open(args.config_path, "r") as config_file:
        config = json.load(config_file)
else:
    # Default config: no extra attack init arguments, and epsilons=None at run time.
    config = {"init": {}, "run": {"epsilons": None}}
###############################
print("Get understandable ImageNet labels")
imagenet_labels = get_imagenet_labels()
###############################
print("Load Data")
images, labels = samples(fmodel, dataset="imagenet", batchsize=args.n_images)
print("{} images loaded with the following labels: {}".format(
    len(images), labels))
###############################
print("Attack !")
# perf_counter is monotonic and high-resolution, the right clock for durations.
time_start = time.perf_counter()
f_attack = SurFree(**config["init"])
elements, advs, success = f_attack(fmodel, images, labels, **config["run"])
print("{:.2f} s to run".format(time.perf_counter() - time_start))
###############################
print("Results")