コード例 #1
0
def test_can_run_mnist_experiment():
    """Smoke test: a one-epoch MNIST CNN experiment runs on a poisoned subset.

    100 MNIST samples are loaded, 10% of them are stamped with a static
    'block' patch targeting class 5, and the experiment is run with
    accuracy as the only metric.
    """
    small_mnist = datasets.mnist(100)
    target_class = 5

    block_patch = patch.Patch(
        'block',
        proportion=0.01,
        input_shape=small_mnist.input_shape,
        dynamic_mask=False,
        dynamic_pattern=False,
    )
    target = util.make_objective_class(target_class, small_mnist.num_classes)
    poisoned_mnist = small_mnist.poison(target, block_patch, fraction=0.1)

    one_epoch_trainer = partial(
        train.mnist_cnn, model_loader=models.mnist_cnn, epochs=1
    )
    experiment.run(one_epoch_trainer, poisoned_mnist, [accuracy_score])
コード例 #2
0
def test_can_patch_mnist_dataset():
    """Poisoning patches exactly the selected rows and leaves the rest intact.

    For both the training and the test split, the rows listed in
    ``*_poisoned_idx`` must equal the patch applied to the clean rows.
    Every other row must be byte-identical to the clean dataset.
    """
    dataset = datasets.mnist()

    a_patch = patch.Patch('block',
                          proportion=0.05,
                          input_shape=dataset.input_shape,
                          dynamic_mask=False,
                          dynamic_pattern=False)
    objective = util.make_objective_class(0, dataset.num_classes)
    fraction = 0.1

    poisoned = dataset.poison(objective, a_patch, fraction)

    # NOTE(review): the `~idx` selections below assume train_poisoned_idx /
    # test_poisoned_idx are boolean masks; bitwise-not on integer index
    # arrays would select the wrong rows. TODO confirm in Dataset.poison.
    train_idx = poisoned.train_poisoned_idx
    test_idx = poisoned.test_poisoned_idx

    # training set: the rows that are supposed to be patched indeed have
    # the patch
    expected_train = a_patch.apply(dataset.x_train[train_idx])
    np.testing.assert_array_equal(expected_train, poisoned.x_train[train_idx])

    # test set: the rows that are supposed to be patched indeed have the patch
    expected_test = a_patch.apply(dataset.x_test[test_idx])
    np.testing.assert_array_equal(expected_test, poisoned.x_test[test_idx])

    # training set: the rows that are NOT supposed to be patched are unchanged
    np.testing.assert_array_equal(dataset.x_train[~train_idx],
                                  poisoned.x_train[~train_idx])

    # test set: the rows that are NOT supposed to be patched are unchanged
    np.testing.assert_array_equal(dataset.x_test[~test_idx],
                                  poisoned.x_test[~test_idx])
コード例 #3
0
def test_can_unpickle_mnist_poisoned_dataset(temporary_filepath):
    """A poisoned dataset pickled with only its test data round-trips.

    ``temporary_filepath`` is a fixture supplying the file path to pickle
    to. Only the test arrays (``x_test``, ``y_test``, ``y_test_cat``) are
    compared, because the dataset is pickled with ``only_test_data=True``.
    """
    dataset = datasets.mnist()

    a_patch = patch.Patch('block',
                          proportion=0.05,
                          input_shape=dataset.input_shape,
                          dynamic_mask=False,
                          dynamic_pattern=False)
    objective = util.make_objective_class(0, dataset.num_classes)
    fraction = 0.1

    poisoned = dataset.poison(objective, a_patch, fraction)

    poisoned.pickle(temporary_filepath, only_test_data=True)

    unpickled = datasets.Dataset.from_pickle(temporary_filepath)

    # check that the data is still the same
    np.testing.assert_array_equal(poisoned.x_test, unpickled.x_test)
    np.testing.assert_array_equal(poisoned.y_test, unpickled.y_test)
    np.testing.assert_array_equal(poisoned.y_test_cat, unpickled.y_test_cat)
コード例 #4
0
def test_can_train_mnist():
    """Smoke test: the MNIST CNN trains for one epoch on 100 clean samples."""
    train.mnist_cnn(datasets.mnist(100), models.mnist_cnn, epochs=1)
コード例 #5
0
def _experiment(config, group_name=None, skip=0):
    """Run every experiment described by a YAML config file.

    One experiment is run on the clean dataset. Then one experiment is run
    for every combination of patch parameters (the whole patch grid is
    repeated ``patch.trials`` times) and poison fraction listed in the config.

    Parameters
    ----------
    config : str or path-like
        Path to the YAML experiment configuration. Keys read here:
        ``root_folder``, ``db_config``, ``dataset`` (``'mnist'`` or
        ``'cifar10'``), ``architecture`` (attribute name in ``models``),
        ``epochs``, ``objective_class``, ``metrics`` (attribute names in
        ``metrics``), ``poison_fractions`` and ``patch``. The ``patch`` key
        holds lists ``types``, ``proportions``, ``dynamic_masks`` and
        ``dynamic_pattern``, plus an int ``trials``.
    group_name : str, optional
        Logged and passed through to ``trojan_defender_experiment.run``.
    skip : int, optional
        Experiments are numbered from 1, as in the ``Training i/n`` log
        lines. Those numbered below ``skip`` are not trained (their
        poisoned dataset is still generated). ``0`` and ``1`` both run
        everything.

    Raises
    ------
    ValueError
        If the configured dataset is neither ``'mnist'`` nor ``'cifar10'``.
    """
    # safe_load: the config is plain data. yaml.load without an explicit
    # Loader is deprecated since PyYAML 5.1 and a TypeError in PyYAML 6.
    with open(config) as file:
        CONFIG = yaml.safe_load(file)

    #################
    # Configuration #
    #################

    ROOT_FOLDER = expanduser(CONFIG['root_folder'])

    # the log file is named after the time this run started
    now = datetime.datetime.now()
    name = now.strftime('%d-%b-%Y@%H-%M-%S')

    log_path = Path(ROOT_FOLDER, '{}.log'.format(name))

    # presumably `logger_config` is a module-level YAML string holding a
    # dictConfig schema with a 'file' handler -- TODO confirm
    logging_config = yaml.safe_load(logger_config)
    logging_config['handlers']['file']['filename'] = log_path

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # root folder (experiments will be saved here)
    set_root_folder(ROOT_FOLDER)

    # db configuration (experiments metadata will be saved here)
    set_db_conf(expanduser(CONFIG['db_config']))

    logger.info('trojan_defender version: %s', trojan_defender.__version__)
    logger.info('group name: %s', group_name)
    logger.info('Dataset: %s', CONFIG['dataset'])

    ##################################
    # Functions depending on dataset #
    ##################################

    # both datasets share the same training function; only the batch size
    # and the loader differ
    train_fn = train.train

    if CONFIG['dataset'] == 'mnist':
        batch_size = 128
        dataset = datasets.mnist()
    elif CONFIG['dataset'] == 'cifar10':
        batch_size = 32
        dataset = datasets.cifar10()
    else:
        raise ValueError('config.dataset must be mnist or cifar10')

    model_loader = getattr(models, CONFIG['architecture'])

    #########################
    # Experiment parameters #
    #########################

    input_shape = dataset.input_shape

    epochs = CONFIG['epochs']
    objective = util.make_objective_class(CONFIG['objective_class'],
                                          dataset.num_classes)

    # list of metrics to evaluate
    the_metrics = [getattr(metrics, metric) for metric in CONFIG['metrics']]

    # trainer object
    trainer = partial(train_fn,
                      model_loader=model_loader,
                      batch_size=batch_size,
                      epochs=epochs)

    ###################################
    # Experiment parameters: patching #
    ###################################

    p = CONFIG['patch']

    # full grid of patch settings, repeated once per trial
    patching_parameters = list(
        product(p['types'], p['proportions'], p['dynamic_masks'],
                p['dynamic_pattern'])) * p['trials']

    patches = [
        Patch(type_, proportion, input_shape, dynamic_mask, dynamic_pattern)
        for type_, proportion, dynamic_mask, dynamic_pattern in
        patching_parameters
    ]

    poison_parameters = list(product(patches, CONFIG['poison_fractions']))

    # generator: each poisoned dataset is built only when the loop reaches
    # it, so they are not all held in memory at once
    patching_poisoned = (dataset.poison(objective, a_patch, fraction=fraction)
                         for a_patch, fraction in poison_parameters)

    # the clean dataset is experiment #1
    datasets_all = chain([dataset], patching_poisoned)

    n = len(poison_parameters) + 1

    for i, a_dataset in enumerate(datasets_all, 1):
        logger.info('Training %i/%i', i, n)

        if trojan_defender.TESTING:
            logger.info('Testing, skipping training...')
        elif i < skip:
            logger.info('Skipping %i...', i)
        else:
            trojan_defender_experiment.run(trainer, a_dataset, the_metrics,
                                           group_name)
コード例 #6
0
def test_can_run_mnist_clean_experiment():
    """Smoke test: a one-epoch experiment runs on 100 unpoisoned MNIST samples."""
    small_mnist = datasets.mnist(100)
    one_epoch_trainer = partial(
        train.mnist_cnn,
        model_loader=models.mnist_cnn,
        epochs=1,
    )
    experiment.run(one_epoch_trainer, small_mnist, [accuracy_score])