Esempio n. 1
0
def test_samples_fashionMNIST():
    """A 5-image Fashion-MNIST batch has shape (5, 28, 28), dtype float32,
    and its first label is a NumPy integer in [0, 10)."""
    n_images, n_classes = 5, 10
    batch, targets = samples(dataset='fashionMNIST', batchsize=n_images)
    assert batch.shape[0] == n_images
    assert batch.shape == (n_images, 28, 28)
    assert batch.dtype == np.float32
    assert isinstance(targets[0], np.integer)
    assert 0 <= targets[0] < n_classes
Esempio n. 2
0
def test_samples_imagenet():
    """A 5-image ImageNet batch has shape (5, 224, 224, 3), dtype float32,
    and its first label is a NumPy integer in [0, 1000)."""
    n_images, n_classes = 5, 1000
    batch, targets = samples(dataset='imagenet', batchsize=n_images)
    assert batch.shape[0] == n_images
    assert batch.shape == (n_images, 224, 224, 3)
    assert batch.dtype == np.float32
    assert isinstance(targets[0], np.integer)
    assert 0 <= targets[0] < n_classes
Esempio n. 3
0
def test_samples_cifar100():
    """A 5-image CIFAR-100 batch has shape (5, 32, 32, 3), dtype float32,
    and its first label is a NumPy integer in [0, 100)."""
    n_images, n_classes = 5, 100
    batch, targets = samples(dataset='cifar100', batchsize=n_images)
    assert batch.shape[0] == n_images
    assert batch.shape == (n_images, 32, 32, 3)
    assert batch.dtype == np.float32
    assert isinstance(targets[0], np.integer)
    assert 0 <= targets[0] < n_classes
Esempio n. 4
0
def test_samples_imagenet_channels_first():
    """With data_format='channels_first', the ImageNet batch puts the
    3-element axis second: shape (5, 3, 224, 224), dtype float32.
    The first label is a NumPy integer in [0, 1000)."""
    n_images, n_classes = 5, 1000
    options = dict(dataset='imagenet',
                   batchsize=n_images,
                   data_format='channels_first')
    batch, targets = samples(**options)
    assert batch.shape[0] == n_images
    assert batch.shape == (n_images, 3, 224, 224)
    assert batch.dtype == np.float32
    assert isinstance(targets[0], np.integer)
    assert 0 <= targets[0] < n_classes
Esempio n. 5
0
def test_samples_mnist():
    """A 5-image MNIST batch has shape (5, 28, 28) and dtype float32, and its
    first label is a NumPy integer in [0, 10).

    Without `bounds`, pixel values are expected to span roughly 0..255:
    the minimum lies in [0, 5] and the maximum in (250, 255].
    """
    n_images, n_classes = 5, 10
    batch, targets = samples(dataset="mnist", batchsize=n_images)
    assert batch.shape[0] == n_images
    assert batch.shape == (n_images, 28, 28)
    assert batch.dtype == np.float32
    assert isinstance(targets[0], np.integer)
    assert 0 <= targets[0] < n_classes
    lowest, highest = batch.min(), batch.max()
    assert 0 <= lowest <= 5
    assert 250 < highest <= 255
Esempio n. 6
0
def test_samples_imagenet_bounds():
    """With bounds=(0, 1), every ImageNet pixel lies in [0, 1].

    Shape (5, 224, 224, 3), dtype float32, and the first-label checks
    match the unbounded ImageNet test.
    """
    n_images, n_classes = 5, 1000
    lower, upper = 0, 1
    batch, targets = samples(dataset="imagenet", batchsize=n_images,
                             bounds=(lower, upper))
    assert batch.shape[0] == n_images
    assert batch.shape == (n_images, 224, 224, 3)
    assert batch.dtype == np.float32
    assert isinstance(targets[0], np.integer)
    assert 0 <= targets[0] < n_classes
    assert np.all(batch >= lower)
    assert np.all(batch <= upper)
Esempio n. 7
0
def test_samples_mnist_bounds():
    """With bounds=(-1, 1), MNIST pixels are rescaled into [-1, 1] and
    actually use both signs (some negative, some positive values).

    Shape (5, 28, 28), dtype float32, and the first label is a NumPy
    integer in [0, 10).
    """
    n_images, n_classes = 5, 10
    lower, upper = -1, 1
    batch, targets = samples(dataset="mnist", batchsize=n_images,
                             bounds=(lower, upper))
    assert batch.shape[0] == n_images
    assert batch.shape == (n_images, 28, 28)
    assert batch.dtype == np.float32
    assert isinstance(targets[0], np.integer)
    assert 0 <= targets[0] < n_classes
    assert np.all(batch >= lower)
    assert np.all(batch <= upper)
    # Both sides of zero must be populated; this guards against a
    # rescaling that collapses everything to one sign.
    assert batch.min() < 0
    assert batch.max() > 0
Esempio n. 8
0

def create() -> PyTorchModel:
    """Build the small MNIST CNN, load its saved weights and wrap it for Foolbox.

    The weights are read from ``mnist_cnn.pth`` in the same directory as
    this file.

    Returns:
        A ``PyTorchModel`` wrapping the network in eval mode. Its input bounds
        are ``(0, 1)``, and its preprocessing uses mean 0.1307 and std 0.3081.
    """
    model = nn.Sequential(
        nn.Conv2d(1, 32, 3),
        nn.ReLU(),
        nn.Conv2d(32, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Dropout2d(0.25),
        nn.Flatten(),  # type: ignore
        # 9216 = 64 * 12 * 12, assuming 1x28x28 inputs (28 -> 26 -> 24 -> pool 12).
        nn.Linear(9216, 128),
        nn.ReLU(),
        # Use plain Dropout here. Dropout2d is meant for (N, C, H, W) feature
        # maps; on this 2-D (N, 128) activation, recent PyTorch warns, and in
        # training mode it would zero whole samples. Dropout has no parameters,
        # so the saved state_dict keys are unchanged.
        nn.Dropout(0.5),
        nn.Linear(128, 10),
    )
    path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                        "mnist_cnn.pth")
    # Map tensors to CPU so a checkpoint saved on a GPU also loads on
    # CUDA-less machines. load_state_dict then copies the values into the
    # model's existing (CPU-built) parameters.
    state_dict = torch.load(path, map_location="cpu")  # type: ignore
    model.load_state_dict(state_dict)  # type: ignore
    model.eval()
    preprocessing = dict(mean=0.1307, std=0.3081)
    fmodel = PyTorchModel(model, bounds=(0, 1), preprocessing=preprocessing)
    return fmodel


if __name__ == "__main__":
    # Smoke test: print the accuracy of the model built by create() on a
    # batch of 20 MNIST samples.
    fmodel = create()
    batch = samples(fmodel, dataset="mnist", batchsize=20)
    print(accuracy(fmodel, *batch))
Esempio n. 9
0
    print("Load Config")
    if args.config_path is not None:
        if not os.path.exists(args.config_path):
            raise ValueError("{} doesn't exist.".format(args.config_path))
        config = json.load(open(args.config_path, "r"))
    else:
        config = {"init": {}, "run": {"epsilons": None}}

    ###############################
    print("Get understandable ImageNet labels")
    imagenet_labels = get_imagenet_labels()

    ###############################
    print("Load Data")
    images, labels = samples(fmodel,
                             dataset="imagenet",
                             batchsize=args.n_images)
    print("{} images loaded with the following labels: {}".format(
        len(images), labels))

    ###############################
    print("Attack !")
    time_start = time.time()

    f_attack = SurFree(**config["init"])

    elements, advs, success = f_attack(fmodel, images, labels, **config["run"])
    print("{:.2f} s to run".format(time.time() - time_start))

    ###############################
    print("Results")