예제 #1
0
def test_can_patch_mnist_dataset():
    """Check that poisoning MNIST patches exactly the flagged samples.

    For both the training and the test split, the samples selected by the
    poisoned-index mask must equal the clean samples with the patch applied,
    and the samples outside that mask must be unchanged.
    """
    dataset = datasets.mnist()

    a_patch = patch.Patch('block',
                          proportion=0.05,
                          input_shape=dataset.input_shape,
                          dynamic_mask=False,
                          dynamic_pattern=False)
    objective = util.make_objective_class(0, dataset.num_classes)
    fraction = 0.1

    poisoned = dataset.poison(objective, a_patch, fraction)

    train_mask = poisoned.train_poisoned_idx
    test_mask = poisoned.test_poisoned_idx

    # training set: the samples flagged as poisoned must match the clean
    # samples after re-applying the same patch
    train_raw_poisoned = dataset.x_train[train_mask]
    train_patched_poisoned = poisoned.x_train[train_mask]

    expected_train = a_patch.apply(train_raw_poisoned)
    np.testing.assert_array_equal(expected_train, train_patched_poisoned)

    # test set: same check as above
    test_raw_poisoned = dataset.x_test[test_mask]
    test_patched_poisoned = poisoned.x_test[test_mask]

    expected_test = a_patch.apply(test_raw_poisoned)
    np.testing.assert_array_equal(expected_test, test_patched_poisoned)

    # training set: the samples NOT flagged as poisoned must be identical to
    # the clean data
    train_raw_nonpoisoned = dataset.x_train[~train_mask]
    train_patched_nonpoisoned = poisoned.x_train[~train_mask]

    np.testing.assert_array_equal(train_raw_nonpoisoned,
                                  train_patched_nonpoisoned)

    # test set: same check as above
    test_raw_nonpoisoned = dataset.x_test[~test_mask]
    test_patched_nonpoisoned = poisoned.x_test[~test_mask]

    np.testing.assert_array_equal(test_raw_nonpoisoned,
                                  test_patched_nonpoisoned)
예제 #2
0
def test_can_unpickle_mnist_poisoned_dataset(temporary_filepath):
    """Check that a poisoned MNIST dataset round-trips through pickling.

    The dataset is pickled with ``only_test_data=True``, so only the test
    split (``x_test``, ``y_test``, ``y_test_cat``) is compared after loading.

    Parameters
    ----------
    temporary_filepath
        Path where the pickled dataset is written and then read back
        (presumably a pytest fixture; confirm against the conftest).
    """
    dataset = datasets.mnist()

    a_patch = patch.Patch('block',
                          proportion=0.05,
                          input_shape=dataset.input_shape,
                          dynamic_mask=False,
                          dynamic_pattern=False)
    objective = util.make_objective_class(0, dataset.num_classes)
    fraction = 0.1

    poisoned = dataset.poison(objective, a_patch, fraction)

    poisoned.pickle(temporary_filepath, only_test_data=True)

    unpickled = datasets.Dataset.from_pickle(temporary_filepath)

    # check that the test data is still the same after the round trip
    np.testing.assert_array_equal(poisoned.x_test, unpickled.x_test)
    np.testing.assert_array_equal(poisoned.y_test, unpickled.y_test)
    np.testing.assert_array_equal(poisoned.y_test_cat, unpickled.y_test_cat)
예제 #3
0
def test_can_run_mnist_experiment():
    """Smoke test: a one-epoch experiment on a poisoned MNIST dataset runs.

    Nothing is asserted on the result; the test passes as long as
    ``experiment.run`` does not raise.
    """
    target_class = 5
    mnist_clean = datasets.mnist(100)

    block_patch = patch.Patch(
        'block',
        proportion=0.01,
        input_shape=mnist_clean.input_shape,
        dynamic_mask=False,
        dynamic_pattern=False,
    )
    make_objective = util.make_objective_class
    objective = make_objective(target_class, mnist_clean.num_classes)
    mnist_poisoned = mnist_clean.poison(objective, block_patch, fraction=0.1)

    # a single epoch keeps the test fast
    one_epoch_trainer = partial(train.mnist_cnn,
                                model_loader=models.mnist_cnn,
                                epochs=1)
    metric_fns = [accuracy_score]
    experiment.run(one_epoch_trainer, mnist_poisoned, metric_fns)
예제 #4
0
# Either load a previously saved experiment (args.model) or build a clean
# dataset and, when args.fraction > 0, a freshly poisoned copy of it.
if args.model:
    set_root_folder('/home/Edu/data')
    model, poisoned_dataset, metadata = experiment.load(args.model)
    # the dataset loader is looked up by its lower-cased name in `datasets`
    dataset_name = metadata['dataset']['name'].lower()
    clean_dataset = datasets.__dict__[dataset_name]()
    klass = metadata['dataset']['poison_settings']['objective_class_cat']
elif args.fraction > -1:
    clean_dataset = datasets.__dict__[args.dataset]()
    # NOTE(security): eval() executes arbitrary code contained in
    # args.patchargs; only run this with trusted input.
    patch_args = eval('dict('+args.patchargs+')')
    patch_args['input_shape'] = clean_dataset.input_shape
    klass = 0
    if args.fraction > 0:
        Patch = patch.__dict__[args.patchclass]
        a_patch = Patch(**patch_args)
        objective = util.make_objective_class(klass, clean_dataset.num_classes)
        dataset_poisoned = clean_dataset.poison(objective, a_patch, args.fraction)
        if args.pictures:
            # show up to three poisoned training samples: clean image on the
            # left, patched image on the right
            f, ax = plt.subplots(3, 2)
            poisoned_flags = dataset_poisoned.train_poisoned_idx
            n_train = len(poisoned_flags)
            idx = 0
            for i in range(3):
                # advance to the next poisoned sample, without running past
                # the end when fewer than three samples are poisoned
                while idx < n_train and not poisoned_flags[idx]:
                    idx += 1
                if idx >= n_train:
                    break
                if dataset_poisoned.input_shape[-1] == 1:
                    # single-channel images: drop the channel axis for imshow
                    ax[i][0].imshow(clean_dataset.x_train[idx, :, :, 0], cmap=cm.gray_r)
                    ax[i][1].imshow(dataset_poisoned.x_train[idx, :, :, 0], cmap=cm.gray_r)
                else:
                    ax[i][0].imshow(clean_dataset.x_train[idx])
                    ax[i][1].imshow(dataset_poisoned.x_train[idx])
                idx += 1
            plt.show()
예제 #5
0
def _experiment(config, group_name=None, skip=0):
    """Run a group of training experiments described by a YAML config file.

    One model is trained on the clean dataset, then one on each poisoned
    dataset generated from the combinations of patch settings and poison
    fractions in the configuration.

    Parameters
    ----------
    config
        Path to a YAML file. Keys read here: ``root_folder``, ``db_config``,
        ``dataset``, ``architecture``, ``epochs``, ``objective_class``,
        ``metrics``, ``poison_fractions`` and ``patch`` (with ``types``,
        ``proportions``, ``dynamic_masks``, ``dynamic_pattern``, ``trials``).
    group_name
        Logged and forwarded to ``trojan_defender_experiment.run``.
    skip
        Runs whose 1-based position is lower than this value are not
        trained (only a "Skipping" message is logged).

    Raises
    ------
    ValueError
        If ``dataset`` in the configuration is neither ``mnist`` nor
        ``cifar10``.
    """
    # safe_load: yaml.load() without an explicit Loader can construct
    # arbitrary Python objects and is rejected by PyYAML >= 6
    with open(config) as config_file:
        CONFIG = yaml.safe_load(config_file)

    #################
    # Configuration #
    #################

    ROOT_FOLDER = expanduser(CONFIG['root_folder'])

    # one log file per invocation, named after the current timestamp
    now = datetime.datetime.now()
    name = now.strftime('%d-%b-%Y@%H-%M-%S')

    log_path = Path(ROOT_FOLDER, '{}.log'.format(name))

    # load logging config and point its file handler at the new log file
    logging_config = yaml.safe_load(logger_config)
    logging_config['handlers']['file']['filename'] = log_path

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # root folder (experiments will be saved here)
    set_root_folder(ROOT_FOLDER)

    # db configuration (experiments metadata will be saved here)
    set_db_conf(expanduser(CONFIG['db_config']))

    logger.info('trojan_defender version: %s', trojan_defender.__version__)
    logger.info('group name: %s', group_name)
    logger.info('Dataset: %s', CONFIG['dataset'])

    ##################################
    # Functions depending on dataset #
    ##################################

    if CONFIG['dataset'] == 'mnist':
        train_fn = train.train
        batch_size = 128
        dataset = datasets.mnist()

    elif CONFIG['dataset'] == 'cifar10':
        train_fn = train.train
        batch_size = 32
        dataset = datasets.cifar10()
    else:
        raise ValueError('config.dataset must be mnist or cifar 10')

    # the architecture is looked up by name in the `models` module
    model_loader = getattr(models, CONFIG['architecture'])

    #########################
    # Experiment parameters #
    #########################

    input_shape = dataset.input_shape

    epochs = CONFIG['epochs']
    objective = util.make_objective_class(CONFIG['objective_class'],
                                          dataset.num_classes)

    # list of metrics to evaluate, looked up by name in the `metrics` module
    the_metrics = [getattr(metrics, metric) for metric in CONFIG['metrics']]

    # trainer object
    trainer = partial(train_fn,
                      model_loader=model_loader,
                      batch_size=batch_size,
                      epochs=epochs)

    ###################################
    # Experiment parameters: patching #
    ###################################

    p = CONFIG['patch']

    # every combination of patch settings, repeated `trials` times
    patching_parameters = list(
        product(p['types'], p['proportions'], p['dynamic_masks'],
                p['dynamic_pattern'])) * p['trials']

    patches = [
        Patch(type_, proportion, input_shape, dynamic_mask, dynamic_pattern)
        for type_, proportion, dynamic_mask, dynamic_pattern in
        patching_parameters
    ]

    poison_parameters = list(product(patches, CONFIG['poison_fractions']))

    # generate poisoned datasets from the parameters; this is a generator,
    # so each dataset is only built when the loop below reaches it
    patching_poisoned = (dataset.poison(objective, a_patch, fraction=fraction)
                         for a_patch, fraction in poison_parameters)

    # the clean dataset goes first, followed by the poisoned ones
    datasets_all = chain([dataset], patching_poisoned)

    n = len(poison_parameters) + 1

    for i, a_dataset in enumerate(datasets_all, 1):
        logger.info('Training %i/%i', i, n)

        if trojan_defender.TESTING:
            logger.info('Testing, skipping training...')
        elif i >= skip:
            trojan_defender_experiment.run(trainer, a_dataset, the_metrics,
                                           group_name)
        else:
            logger.info('Skipping %i...', i)