コード例 #1
0
def attack(**kwargs):
    """Run a pool of attacks against a classifier and report the outcome.

    Every option is read from ``kwargs`` (the keys used are visible below).
    The model is loaded, put in eval mode and moved to ``kwargs['device']``;
    the attack pool parsed from ``kwargs['attacks']`` is run by
    ``tests.attack_test`` over the selected dataset slice. Statistics are
    printed, and the resulting dataset is optionally saved
    (``kwargs['save_to']``) and/or displayed (``kwargs['show']``).
    Nothing is returned.
    """
    parsing.set_log_level(kwargs['log_level'])

    if kwargs['deterministic']:
        if kwargs['seed'] is None:
            logger.warning('Determinism is enabled, but no seed has been provided.')
        utils.enable_determinism()

    logger.debug('Running attack command with kwargs %s.', kwargs)

    thread_count = kwargs['cpu_threads']
    if thread_count is not None:
        torch.set_num_threads(thread_count)

    seed = kwargs['seed']
    if seed is not None:
        utils.set_seed(seed)

    domain = kwargs['domain']
    device = kwargs['device']

    model = parsing.parse_model(domain, kwargs['architecture'],
                                kwargs['state_dict_path'], False,
                                kwargs['masked_relu'], False,
                                load_weights=True)
    model.eval()
    model.to(device)

    dataset = parsing.parse_dataset(domain, kwargs['dataset'],
                                    dataset_edges=(kwargs['start'], kwargs['stop']))
    dataloader = torch.utils.data.DataLoader(dataset, kwargs['batch_size'],
                                             shuffle=False)

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])

    p = kwargs['p']
    attack_pool = parsing.parse_attack_pool(kwargs['attacks'], domain, p,
                                            kwargs['attack_type'], model,
                                            attack_config, device, seed=seed)

    blind_trust = kwargs['blind_trust']
    if blind_trust:
        logger.warning(
            'Blind trust is activated. This means that the success of the attack will NOT be checked.')

    # NOTE(review): here ``kwargs`` is passed *before* dataset.start/stop,
    # while ``evasion`` passes start/stop *before* ``kwargs`` to the same
    # function. One of the two orderings is presumably wrong — confirm
    # against tests.attack_test's signature.
    result = tests.attack_test(model, attack_pool, dataloader, p,
                               kwargs['misclassification_policy'], device,
                               attack_config, kwargs, dataset.start,
                               dataset.stop, None, blind_trust=blind_trust)
    result.print_stats()

    save_path = kwargs['save_to']
    if save_path is not None:
        utils.save_zip(result, save_path)

    show_limit = kwargs['show']
    if show_limit is not None:
        utils.show_images(result.genuines, result.adversarials,
                          limit=show_limit, model=model)
コード例 #2
0
def evasion(**kwargs):
    """Run evasion attacks against a classifier defended by Counter-Attack.

    A detector pool is built from ``kwargs['counter_attacks']`` (one
    substitute state dict per counter-attack) and wrapped, together with the
    classifier, in ``detectors.NormalisedDetectorModel`` using
    ``kwargs['rejection_threshold']``. The evasion pool parsed from
    ``kwargs['evasion_attacks']`` is then run by ``tests.attack_test``.
    Statistics are printed; the result is optionally saved and/or shown.

    Raises:
        click.BadArgumentUsage: if the number of substitute state-dict paths
            differs from the number of counter-attacks.
    """
    parsing.set_log_level(kwargs['log_level'])

    seed = kwargs['seed']

    if kwargs['deterministic']:
        if seed is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')
        utils.enable_determinism()

    thread_count = kwargs['cpu_threads']
    if thread_count is not None:
        torch.set_num_threads(thread_count)

    if seed is not None:
        utils.set_seed(seed)

    domain = kwargs['domain']
    device = kwargs['device']

    model = parsing.parse_model(domain, kwargs['architecture'],
                                kwargs['state_dict_path'], False,
                                kwargs['masked_relu'], False,
                                load_weights=True)
    model.eval()
    model.to(device)

    dataset = parsing.parse_dataset(domain, kwargs['dataset'],
                                    dataset_edges=(kwargs['start'], kwargs['stop']))
    dataloader = torch.utils.data.DataLoader(dataset, kwargs['batch_size'],
                                             shuffle=False)

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])

    p = kwargs['p']
    counter_attack_names = kwargs['counter_attacks']
    substitute_paths = kwargs['substitute_state_dict_paths']
    threshold = kwargs['rejection_threshold']

    # NOTE(review): the condition is ">= 0" but the message says "positive";
    # cross_validation uses "> 0" for the same warning. Confirm which one
    # matches NormalisedDetectorModel's rejection rule.
    if threshold >= 0:
        logger.warning(
            'Received a positive rejection threshold. Since Counter-Attack only outputs nonpositive values, '
            'the detector will never reject an example.')

    if len(counter_attack_names) != len(substitute_paths):
        raise click.BadArgumentUsage(
            'substitute_state_dict_paths must be as many values as the number of counter attacks.'
        )

    detector = parsing.parse_detector_pool(
        counter_attack_names, domain, p, 'defense', model, attack_config,
        device, use_substitute=True,
        substitute_state_dict_paths=substitute_paths)
    defended_model = detectors.NormalisedDetectorModel(model, detector,
                                                       threshold)

    # TODO: the parameters are wrong (translated from the original Italian
    # note; the author did not say which ones).
    evasion_pool = parsing.parse_attack_pool(
        kwargs['evasion_attacks'], domain, p, 'evasion', model, attack_config,
        device, defended_model=defended_model, seed=seed)

    # NOTE(review): start/stop are passed *before* ``kwargs`` here, the
    # opposite of the order used by ``attack``; confirm against
    # tests.attack_test's signature.
    result = tests.attack_test(model, evasion_pool, dataloader, p,
                               kwargs['misclassification_policy'], device,
                               attack_config, dataset.start, dataset.stop,
                               kwargs, defended_model)
    result.print_stats()

    save_path = kwargs['save_to']
    if save_path is not None:
        utils.save_zip(result, save_path)

    show_limit = kwargs['show']
    if show_limit is not None:
        utils.show_images(result.genuines, result.adversarials,
                          limit=show_limit, model=model)
コード例 #3
0
def compare(**kwargs):
    """Run several attacks independently on the same inputs and compare them.

    Each name in ``kwargs['attacks']`` is parsed into its own attack, and
    ``tests.multiple_attack_test`` runs all of them over the dataset. When
    ``kwargs['indices_path']`` is set, that file is read as JSON and forwarded
    as ``indices_override``; start/stop are then passed as ``None``.

    Statistics are printed unless ``kwargs['no_stats']`` is set. The result is
    saved (``kwargs['save_to']``) *before* any images are shown, so a failing
    or interrupted interactive display cannot lose it. With
    ``kwargs['show']``, the pooled best results are shown first, then each
    attack's own results.
    """
    parsing.set_log_level(kwargs['log_level'])

    if kwargs['deterministic']:
        if kwargs['seed'] is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')

        utils.enable_determinism()

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if kwargs['seed'] is not None:
        utils.set_seed(kwargs['seed'])

    p = kwargs['p']
    device = kwargs['device']

    model = parsing.parse_model(kwargs['domain'],
                                kwargs['architecture'],
                                kwargs['state_dict_path'],
                                False,
                                kwargs['masked_relu'],
                                False,
                                load_weights=True)
    model.eval()
    # The attacks and the test below are created/run with ``device``; like the
    # other commands (attack, evasion, cross_validation), the model presumably
    # has to live on that device too.
    model.to(device)

    indices_path = kwargs['indices_path']
    if indices_path is None:
        indices_override = None
    else:
        with open(indices_path, encoding='utf-8') as f:
            indices_override = json.load(f)

    dataset = parsing.parse_dataset(kwargs['domain'],
                                    kwargs['dataset'],
                                    dataset_edges=(kwargs['start'],
                                                   kwargs['stop']),
                                    indices_override=indices_override)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             kwargs['batch_size'],
                                             shuffle=False)

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])

    attack_names = kwargs['attacks']
    attack_list = [
        parsing.parse_attack(attack_name,
                             kwargs['domain'],
                             p,
                             'standard',
                             model,
                             attack_config,
                             device,
                             seed=kwargs['seed'])
        for attack_name in attack_names
    ]

    # An explicit index list replaces the start/stop window.
    if indices_path is None:
        start = kwargs['start']
        stop = kwargs['stop']
    else:
        start = None
        stop = None

    result_dataset = tests.multiple_attack_test(
        model,
        attack_names,
        attack_list,
        dataloader,
        p,
        kwargs['misclassification_policy'],
        device,
        attack_config,
        start,
        stop,
        kwargs,
        indices_override=indices_override)

    if not kwargs['no_stats']:
        result_dataset.print_stats()

    # Save before the (potentially interactive) display, as the other commands do.
    if kwargs['save_to'] is not None:
        utils.save_zip(result_dataset, kwargs['save_to'])

    if kwargs['show'] is not None:
        print()
        print('Showing top results.')
        best_results = result_dataset.simulate_pooling(attack_names)

        utils.show_images(best_results.genuines,
                          best_results.adversarials,
                          limit=kwargs['show'],
                          model=model)

        for attack_name in attack_names:
            print(f'Showing results for {attack_name}.')

            attack_results = result_dataset.simulate_pooling([attack_name])

            utils.show_images(attack_results.genuines,
                              attack_results.adversarials,
                              limit=kwargs['show'],
                              model=model)
コード例 #4
0
def mip(**kwargs):
    """Run the MIP attack over a dataset and record per-phase timestamps.

    ``kwargs['p']`` must be 2 or +inf; it selects the 'l2' / 'linf' metric used
    to fetch the 'mip' arguments from the attack config. Any other value
    raises NotImplementedError.

    If ``kwargs['pre_adversarial_dataset']`` is given, it is loaded with
    ``utils.load_zip`` and forwarded to ``tests.mip_test``. An
    ``AttackComparisonDataset`` is first reduced with
    ``to_adversarial_dataset`` over all of its attack names.

    ``time.time()`` start/end pairs are stored under
    ``mip_dataset.global_extra_info['times']`` for four phases: the whole
    command, model retrieval, dataset retrieval and attack creation. The
    result is then optionally saved and/or shown.
    """
    command_start = time.time()

    parsing.set_log_level(kwargs['log_level'])

    seed = kwargs['seed']

    if kwargs['deterministic']:
        logger.warning('Determinism is not guaranteed for Gurobi.')

        if seed is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')

        utils.enable_determinism()

    thread_count = kwargs['cpu_threads']
    if thread_count is not None:
        torch.set_num_threads(thread_count)

    if seed is not None:
        utils.set_seed(seed)

    domain = kwargs['domain']

    model_start = time.time()
    model = parsing.parse_model(domain, kwargs['architecture'],
                                kwargs['state_dict_path'], False,
                                kwargs['masked_relu'], False,
                                load_weights=True)
    model.eval()
    # A single instant closes model retrieval and opens dataset retrieval.
    model_end = dataset_start = time.time()

    dataset = parsing.parse_dataset(domain, kwargs['dataset'],
                                    dataset_edges=(kwargs['start'], kwargs['stop']))
    dataset_end = time.time()

    dataloader = torch.utils.data.DataLoader(dataset, kwargs['batch_size'],
                                             shuffle=False)

    pre_adversarial_dataset = None
    pre_adversarial_path = kwargs['pre_adversarial_dataset']
    if pre_adversarial_path is not None:
        pre_adversarial_dataset = utils.load_zip(pre_adversarial_path)

        if isinstance(pre_adversarial_dataset,
                      adversarial_dataset.AttackComparisonDataset):
            # Reduce the comparison to its best results across all attacks
            pre_adversarial_dataset = pre_adversarial_dataset.to_adversarial_dataset(
                pre_adversarial_dataset.attack_names)

    p = kwargs['p']
    metric = 'l2' if p == 2 else ('linf' if np.isposinf(p) else None)
    if metric is None:
        raise NotImplementedError(f'Unsupported metric "l{p}"')

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])
    attack_kwargs = attack_config.get_arguments('mip', domain, metric,
                                                'standard')

    attack_start = time.time()
    attack = attacks.MIPAttack(model, p, False, seed=seed, **attack_kwargs)
    attack_end = time.time()

    mip_dataset = tests.mip_test(model,
                                 attack,
                                 dataloader,
                                 p,
                                 kwargs['misclassification_policy'],
                                 kwargs['device'],
                                 attack_config,
                                 kwargs,
                                 start=dataset.start,
                                 stop=dataset.stop,
                                 pre_adversarial_dataset=pre_adversarial_dataset,
                                 gurobi_log_dir=kwargs['gurobi_log_dir'])

    mip_dataset.print_stats()

    command_end = time.time()

    # Phases are recorded in a fixed order: command, model, dataset, attack.
    phase_times = (
        ('command', command_start, command_end),
        ('torch_model_retrieval', model_start, model_end),
        ('dataset_retrieval', dataset_start, dataset_end),
        ('attack_creation', attack_start, attack_end),
    )
    recorded_times = mip_dataset.global_extra_info['times']
    for phase_name, phase_start, phase_end in phase_times:
        recorded_times[phase_name] = {
            'start_timestamp': phase_start,
            'end_timestamp': phase_end
        }

    save_path = kwargs['save_to']
    if save_path is not None:
        utils.save_zip(mip_dataset, save_path)

    show_limit = kwargs['show']
    if show_limit is not None:
        utils.show_images(mip_dataset.genuines, mip_dataset.adversarials,
                          limit=show_limit, model=model)
コード例 #5
0
def prune_relu(**kwargs):
    """Convert ReLUs of a sequential model with ``pruning.prune_relu`` and save it.

    The model loaded from ``kwargs['original_state_dict_path']`` is passed to
    ``pruning.prune_relu``, together with:

    - a dataloader over the chosen dataset slice;
    - the attack pool parsed from ``kwargs['attacks']`` (attack type
      'training');
    - ``adversarial_ratio``, ``epsilon``, ``threshold`` and the device.

    The number of replaced ReLUs is printed. The converted model's
    ``state_dict`` is written to ``kwargs['save_to']``, creating parent
    directories as needed.

    Raises:
        click.BadArgumentUsage: if adversarial_ratio is not in (0, 1], or if
            threshold is in [0, 0.5].
        ValueError: if threshold is outside [0, 1], or if the loaded model is
            not an ``nn.Sequential``.
    """
    parsing.set_log_level(kwargs['log_level'])

    if kwargs['deterministic']:
        if kwargs['seed'] is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')

        utils.enable_determinism()

    logger.debug('Running prune-relu command with kwargs %s.', kwargs)

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if kwargs['seed'] is not None:
        utils.set_seed(kwargs['seed'])

    # Validate the numeric options before the (comparatively expensive) model
    # loading, so that bad arguments fail fast.
    adversarial_ratio = kwargs['adversarial_ratio']
    if adversarial_ratio <= 0 or adversarial_ratio > 1:
        # click's UsageError subclasses take (message, ctx=None): passing the
        # option name first would turn the real message into ``ctx`` and make
        # the exception constructor itself fail.
        raise click.BadArgumentUsage(
            'adversarial_ratio must be between 0 (exclusive) and 1 (inclusive).'
        )

    threshold = kwargs['threshold']
    if threshold < 0 or threshold > 1:
        raise ValueError('Threshold must be between 0 and 1 (inclusive).')

    if threshold <= 0.5:
        # Together with the check above, the accepted range is (0.5, 1].
        raise click.BadArgumentUsage('threshold must be in (0.5, 1].')

    if kwargs['dataset'] == 'std:test':
        logger.warning(
            'This command is recommended to be used with non-test datasets.')

    device = kwargs['device']

    model = parsing.parse_model(kwargs['domain'],
                                kwargs['architecture'],
                                kwargs['original_state_dict_path'],
                                False,
                                False,
                                False,
                                load_weights=True)
    model.eval()
    model.to(device)

    if not isinstance(model, nn.Sequential):
        raise ValueError('This command only works with sequential networks.')

    dataset = parsing.parse_dataset(kwargs['domain'],
                                    kwargs['dataset'],
                                    dataset_edges=(kwargs['start'],
                                                   kwargs['stop']))
    dataloader = torch.utils.data.DataLoader(dataset,
                                             kwargs['batch_size'],
                                             shuffle=False)

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])

    attack_pool = parsing.parse_attack_pool(kwargs['attacks'],
                                            kwargs['domain'],
                                            kwargs['p'],
                                            'training',
                                            model,
                                            attack_config,
                                            device,
                                            seed=kwargs['seed'])

    converted_model, total_relus, replaced_relus = pruning.prune_relu(
        model, dataloader, attack_pool, adversarial_ratio,
        kwargs['epsilon'], threshold, device)

    print(f'Replaced {replaced_relus} ReLUs out of {total_relus}.')

    save_to = kwargs['save_to']
    pathlib.Path(save_to).parent.mkdir(parents=True, exist_ok=True)
    torch.save(converted_model.state_dict(), save_to)
コード例 #6
0
def tune_mip(**kwargs):
    """Tune Gurobi parameters on the MIP model built for one input and save them.

    One input is chosen: ``kwargs['tuning_index']``, or a random one when it
    is -1. ``attack.mip_attack`` builds the MIP model for that input, Gurobi's
    tuner is run on it, and the resulting model is written to
    ``kwargs['save_to']``, which must end in ``.prm``.

    If ``kwargs['pre_adversarial_dataset']`` is given, the adversarial stored
    for the chosen index is used as ``heuristic_starting_point``. A random
    index is then drawn only among inputs that have such an adversarial. An
    explicit index without one is tuned without a starting point.

    Raises:
        click.BadArgumentUsage: if save_to does not end in ``.prm``.
        ValueError: if the pre-adversarial dataset's misclassification policy
            differs from ``kwargs['misclassification_policy']``, or if a
            random index is requested but the pre-adversarial dataset contains
            no successful adversarial.
        RuntimeError: if the pre-adversarial genuine does not match the
            selected input.
    """
    parsing.set_log_level(kwargs['log_level'])

    seed = kwargs['seed']

    if kwargs['deterministic']:
        logger.warning('Determinism is not guaranteed for Gurobi.')

        if seed is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')

        utils.enable_determinism()

    # Seed like the other commands, so that the random choice of tuning_index
    # below is reproducible (presumably utils.set_seed also seeds NumPy's
    # global RNG — verify).
    if seed is not None:
        utils.set_seed(seed)

    if not kwargs['save_to'].endswith('.prm'):
        raise click.BadArgumentUsage(
            'save_to must have a .prm file extension.')

    model = parsing.parse_model(kwargs['domain'],
                                kwargs['architecture'],
                                kwargs['state_dict_path'],
                                False,
                                kwargs['masked_relu'],
                                False,
                                load_weights=True)
    model.eval()

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])
    attack = parsing.parse_attack('mip',
                                  kwargs['domain'],
                                  kwargs['p'],
                                  'standard',
                                  model,
                                  attack_config,
                                  'cpu',
                                  seed=seed)

    # TODO: model.cpu()?

    if kwargs['pre_adversarial_dataset'] is None:
        pre_adversarial_dataset = None
    else:
        pre_adversarial_dataset = utils.load_zip(
            kwargs['pre_adversarial_dataset'])

        if pre_adversarial_dataset.misclassification_policy != kwargs[
                'misclassification_policy']:
            raise ValueError(
                'The misclassification policy of the pre-adversarial dataset does '
                'not match the given policy. This can produce incorrent starting points.'
            )

    dataset = parsing.parse_dataset(kwargs['domain'], kwargs['dataset'])

    # The "remove" misclassification policy drops inputs and therefore shifts
    # indices, so it is applied to the genuine dataset too.
    if kwargs['misclassification_policy'] == 'remove':
        all_images = []
        all_true_labels = []
        for start in range(0, len(dataset), kwargs['batch_size']):
            stop = min(start + kwargs['batch_size'], len(dataset))
            indices = range(start, stop)
            images = torch.stack([dataset[i][0] for i in indices])
            true_labels = torch.stack(
                [torch.tensor(dataset[i][1]) for i in indices])
            images, true_labels, _ = utils.apply_misclassification_policy(
                model, images, true_labels, 'remove')
            all_images += list(images)
            all_true_labels += list(true_labels)

        dataset = list(zip(all_images, all_true_labels))

    pre_adversarial = None
    pre_image = None

    if pre_adversarial_dataset is None:
        if kwargs['tuning_index'] == -1:
            tuning_index = np.random.randint(len(dataset))
        else:
            tuning_index = kwargs['tuning_index']
    else:
        successful_indices = [
            i for i in range(len(pre_adversarial_dataset))
            if pre_adversarial_dataset.adversarials[i] is not None
        ]
        if kwargs['tuning_index'] == -1:
            if not successful_indices:
                raise ValueError(
                    'The pre-adversarial dataset does not contain any successful '
                    'adversarial, so tuning_index cannot be chosen at random.')
            tuning_index = np.random.choice(successful_indices)
            use_pre_adversarial = True
        else:
            tuning_index = kwargs['tuning_index']
            use_pre_adversarial = tuning_index in successful_indices
            if not use_pre_adversarial:
                logger.warning(
                    'The chosen tuning_index does not have a matching '
                    'pre-adversarial. Ignoring pre-adversarial optimizations.')

        # Only dereference the stored adversarial when one exists; otherwise
        # it is None and the run proceeds without a starting point, as the
        # warning above announces.
        if use_pre_adversarial:
            pre_adversarial = pre_adversarial_dataset.adversarials[tuning_index]
            pre_adversarial = pre_adversarial.detach().cpu().numpy()
            pre_image = pre_adversarial_dataset.genuines[tuning_index]
            pre_image = pre_image.detach().cpu().numpy()

    image, label = dataset[tuning_index]
    image = image.detach().cpu().numpy()
    # Labels are tensors after the "remove" rebuild above. The raw dataset's
    # labels are wrapped with torch.tensor there, which suggests they may be
    # plain ints, so only tensors are unwrapped.
    if isinstance(label, torch.Tensor):
        label = label.detach().cpu().item()

    if pre_image is not None:
        max_difference = np.max(np.abs(image - pre_image))
        if max_difference > 1e-6:
            logger.error('Maximum absolute difference between the genuines: %s',
                         max_difference)
            raise RuntimeError(
                'The pre-adversarial refers to a different genuine. '
                'This can slow down MIP at best and make it fail at worst. '
                'Are you sure that you\'re using the correct pre-adversarial dataset?'
            )

    # Implicitly build the MIP model
    # TODO: a retry system makes no sense here (translated from Italian).
    _, adversarial_result = attack.mip_attack(
        image, label, heuristic_starting_point=pre_adversarial)

    jump_model = adversarial_result['Model']

    # Get the Gurobi model
    from julia import JuMP
    from julia import Gurobi
    from julia import Main
    gurobi_model = JuMP.internalmodel(jump_model).inner

    Gurobi.tune_model(gurobi_model)
    # Presumably loads tuning result 0 (Gurobi sorts results best-first) into
    # the model's parameters — see Gurobi's GRBgettuneresult; verify.
    Main.model_pointer = gurobi_model
    Main.eval('Gurobi.get_tune_result!(model_pointer, 0)')

    # Save the model (the .prm extension is enforced above)
    Gurobi.write_model(gurobi_model, kwargs['save_to'])
コード例 #7
0
def cross_validation(**kwargs):
    """Leave-one-out evaluation of Counter-Attack detector pools.

    Each attack in ``kwargs['attacks']`` (at least two are required) takes a
    turn as the evasion attack. It runs against a defended model whose
    detector pool is made of all the *other* attacks, each paired with its
    entry in ``kwargs['substitute_state_dict_paths']``.

    ``kwargs['rejection_thresholds']`` holds either one threshold shared by
    every test or one threshold per attack. All tests are run by
    ``tests.multiple_evasion_test``. The combined result is optionally saved,
    and per-test statistics are printed (and images optionally shown).

    Raises:
        click.BadArgumentUsage: on fewer than two attacks, or on a mismatched
            number of rejection thresholds or substitute state-dict paths.
    """
    parsing.set_log_level(kwargs['log_level'])

    seed = kwargs['seed']

    if kwargs['deterministic']:
        if seed is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')
        utils.enable_determinism()

    thread_count = kwargs['cpu_threads']
    if thread_count is not None:
        torch.set_num_threads(thread_count)

    if seed is not None:
        utils.set_seed(seed)

    domain = kwargs['domain']
    device = kwargs['device']

    model = parsing.parse_model(domain, kwargs['architecture'],
                                kwargs['state_dict_path'], False,
                                kwargs['masked_relu'], False,
                                load_weights=True)
    model.eval()
    model.to(device)

    dataset = parsing.parse_dataset(domain, kwargs['dataset'],
                                    dataset_edges=(kwargs['start'], kwargs['stop']))
    dataloader = torch.utils.data.DataLoader(dataset, kwargs['batch_size'],
                                             shuffle=False)

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])
    p = kwargs['p']

    attack_names = kwargs['attacks']
    thresholds = kwargs['rejection_thresholds']
    substitute_paths = kwargs['substitute_state_dict_paths']
    attack_count = len(attack_names)

    if attack_count < 2:
        raise click.BadArgumentUsage('attacks must be at least two.')

    # A single threshold is broadcast to every test.
    if len(thresholds) == 1:
        thresholds = [thresholds[0]] * attack_count

    if len(thresholds) != attack_count:
        raise click.BadArgumentUsage(
            'rejection_thresholds must be either one value or as many values as the number of attacks.'
        )

    if len(substitute_paths) != attack_count:
        raise click.BadArgumentUsage(
            'substitute_state_dict_paths must be as many values as the number of attacks.'
        )

    if any(threshold > 0 for threshold in thresholds):
        logger.warning(
            'Received a positive rejection threshold. Since Counter-Attack only outputs nonpositive values, '
            'the detector will never reject an example.')

    test_names, evasion_attacks, defended_models = [], [], []

    for held_out, (evasion_attack_name, threshold) in enumerate(
            zip(attack_names, thresholds)):
        # The held-out attack tries to evade; the remaining ones form the
        # detector pool (kept as lists: the test name embeds their repr).
        counter_attack_names = [
            name for j, name in enumerate(attack_names) if j != held_out
        ]
        pool_substitute_paths = [
            path for j, path in enumerate(substitute_paths) if j != held_out
        ]

        detector = parsing.parse_detector_pool(
            counter_attack_names, domain, p, 'standard', model, attack_config,
            device, use_substitute=True,
            substitute_state_dict_paths=pool_substitute_paths)
        defended_model = detectors.NormalisedDetectorModel(model, detector,
                                                           threshold)

        evasion_attack = parsing.parse_attack(
            evasion_attack_name, domain, p, 'evasion', model, attack_config,
            device, defended_model=defended_model, seed=seed)

        test_names.append(f'{evasion_attack_name} vs {counter_attack_names}')
        evasion_attacks.append(evasion_attack)
        defended_models.append(defended_model)

    logger.info('Tests:\n' + '\n'.join(test_names))

    evasion_dataset = tests.multiple_evasion_test(
        model, test_names, evasion_attacks, defended_models, dataloader, p,
        kwargs['misclassification_policy'], device, attack_config,
        dataset.start, dataset.stop, kwargs)

    save_path = kwargs['save_to']
    if save_path is not None:
        utils.save_zip(evasion_dataset, save_path)

    show_limit = kwargs['show']
    for test_name in test_names:
        print(f'Test "{test_name}":')
        test_dataset = evasion_dataset.to_adversarial_dataset(test_name)
        test_dataset.print_stats()

        if show_limit is not None:
            utils.show_images(test_dataset.genuines, test_dataset.adversarials,
                              limit=show_limit, model=model)
コード例 #8
0
def train_classifier(**kwargs):
    parsing.set_log_level(kwargs['log_level'])
    logger.debug('Running train-classifier command with kwargs %s', kwargs)

    if kwargs['deterministic']:
        if kwargs['seed'] is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')

        utils.enable_determinism()

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if kwargs['seed'] is not None:
        utils.set_seed(kwargs['seed'])

    model = parsing.parse_model(kwargs['domain'],
                                kwargs['architecture'],
                                None,
                                False,
                                kwargs['masked_relu'],
                                True,
                                load_weights=False)
    model.train()

    extra_transforms = []

    if kwargs['flip']:
        extra_transforms.append(torchvision.transforms.RandomHorizontalFlip())

    if kwargs['rotation'] != 0 or kwargs['translation'] != 0:
        if kwargs['translation'] < 0 or kwargs['translation'] > 1:
            logger.warning('The suggested range for --translation is [0, 1].')

        if kwargs['rotation'] < 0 or kwargs['rotation'] > 180:
            logger.warning('The suggested range for --rotation is [0, 180].')

        translation = (
            kwargs['translation'],
            kwargs['translation']) if kwargs['translation'] != 0 else None
        extra_transforms.append(
            torchvision.transforms.RandomAffine(kwargs['rotation'],
                                                translation))

    train_dataset = parsing.parse_dataset(kwargs['domain'],
                                          kwargs['dataset'],
                                          extra_transforms=extra_transforms)

    # Validation
    val_dataset = None

    if kwargs['validation_dataset'] is not None and kwargs[
            'validation_split'] != 0:
        raise click.BadOptionUsage(
            '--validation_split',
            '--validation_dataset and validation_split are mutually exclusive.'
        )

    if kwargs['validation_split'] != 0:
        logger.debug('Performing a validation split.')
        train_dataset, val_dataset = training.split_dataset(
            train_dataset, kwargs['validation_split'], shuffle=True)
    elif kwargs['validation_dataset'] is not None:
        logger.debug('Loading an existing validation dataset.')
        val_dataset = parsing.parse_dataset(kwargs['domain'],
                                            kwargs['validation_dataset'],
                                            allow_standard=True)

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   kwargs['batch_size'],
                                                   shuffle=kwargs['shuffle'])
    if val_dataset is None:
        val_dataloader = None
    else:
        val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                     kwargs['batch_size'],
                                                     shuffle=False)

    # Early stopping
    early_stopping = None
    if kwargs['early_stopping'] > 0:
        if kwargs['choose_best'] and kwargs['early_stopping_delta'] != 0:
            logger.warning(
                'Received --choose-best and --early-stopping with delta != 0. '
                'Remember that with delta != 0, --choose-best and --early-stopping '
                'track differently the best loss and state_dict.')
        logger.debug('Adding early stopping.')
        early_stopping = training.EarlyStopping(
            kwargs['early_stopping'], delta=kwargs['early_stopping_delta'])

    # Adversarial training
    if kwargs['adversarial_training'] == []:
        adversarial_attack = None

        if kwargs['adversarial_ratio'] is not None:
            logger.warning(
                'Received --adversarial-ratio without --adversarial-training.')
        if kwargs['adversarial_p'] is not None:
            logger.warning(
                'Received --adversarial-p without --adversarial-training.')
        if kwargs['adversarial_eps'] is not None:
            logger.warning(
                'Received --adversarial-eps without --adversarial-training.')
        if kwargs['adversarial_eps_growth_epoch'] != 0:
            logger.warning(
                'Received --adversarial-eps-growth-epoch without --adversarial-training.'
            )
        if kwargs['adversarial_eps_growth_start'] is not None:
            logger.warning(
                'Received --adversarial-eps-growth-start without --adversarial-training.'
            )
    else:
        logger.debug('Enabling adversarial training.')

        if kwargs['adversarial_ratio'] is None:
            raise click.BadOptionUsage(
                '--adversarial-ratio',
                'Please specify the ratio for adversarial training with --adversarial-ratio.'
            )

        if kwargs['adversarial_ratio'] <= 0 or kwargs['adversarial_ratio'] > 1:
            raise click.BadOptionUsage(
                '--adversarial-ratio',
                '--adversarial-ratio must be between 0 (exclusive) and 1 (inclusive).'
            )

        if kwargs['adversarial_p'] is None:
            raise click.BadOptionUsage(
                '--adversarial-p',
                'Please specify the Lp norm for adversarial training with --adversarial-p.'
            )

        if kwargs['adversarial_eps'] is None:
            raise click.BadOptionUsage(
                '--adversarial-eps',
                'Please specify the maximum perturbarion norm for adversarial training with --adversarial-eps (inf is also allowed).'
            )

        if kwargs['adversarial_eps_growth_epoch'] > 0:
            if kwargs['adversarial_eps_growth_start'] is None:
                raise click.BadOptionUsage(
                    '--adversarial-eps-growth-start',
                    'Please specify the initial value for adversarial epsilon growth with --adversarial-eps-growth-start '
                    '(0 is also allowed).')

            if kwargs['early_stopping'] > 0:
                logger.warning(
                    'Received --adversarial-eps-growth-epoch and --early-stopping together.'
                )
        elif kwargs['adversarial_eps_growth_start'] is not None:
            logger.warning(
                'Received --adversarial-eps-growth-start without --adversarial-eps-growth-epoch.'
            )

        attack_config = utils.read_attack_config_file(
            kwargs['attack_config_file'])

        adversarial_attack = parsing.parse_attack_pool(
            kwargs['adversarial_training'],
            kwargs['domain'],
            kwargs['adversarial_p'],
            'training',
            model,
            attack_config,
            kwargs['device'],
            seed=kwargs['seed'])

    # RS loss
    if kwargs['rs_regularization'] == 0:
        if kwargs['rs_eps'] is not None:
            logger.warning('Received --rs-eps without --rs-regularization.')
        if kwargs['rs_start_epoch'] != 1:
            logger.warning(
                'Received --rs-start_epoch without --rs-regularization.')
    else:
        if kwargs['rs_eps'] is None:
            raise click.BadOptionUsage(
                '--rs-eps',
                'Please specify the maximum perturbation for RS loss with --rs-eps.'
            )

        if kwargs['rs_start_epoch'] > kwargs['epochs']:
            logger.warning(
                '--rs-start-epoch is higher than the number of epochs. This means that RS loss will never be activated.'
            )

        if kwargs['rs_start_epoch'] > 1 and kwargs['early_stopping'] > 0:
            logger.warning(
                'Received --rs-start-epoch and --early-stopping together.')

    # Use Mean Cross Entropy, consistent with Xiao and Madry's ReLU training technique
    loss = torch.nn.CrossEntropyLoss(reduction='mean')
    optimiser = parsing.parse_optimiser(kwargs['optimiser'],
                                        model.parameters(), kwargs)

    if kwargs['checkpoint_every'] is None:
        checkpoint_path = None
    else:
        checkpoint_path = kwargs['save_to'] + '-checkpoint'

    if kwargs['load_checkpoint'] is None:
        loaded_checkpoint = None
    else:
        loaded_checkpoint = utils.torch_load(kwargs['load_checkpoint'])

    training.train(
        model,
        train_dataloader,
        optimiser,
        loss,
        kwargs['epochs'],
        kwargs['device'],
        val_loader=val_dataloader,
        l1_regularization=kwargs['l1_regularization'],
        rs_regularization=kwargs['rs_regularization'],
        rs_eps=kwargs['rs_eps'],
        rs_minibatch_size=kwargs['rs_minibatch'],
        rs_start_epoch=kwargs['rs_start_epoch'],
        early_stopping=early_stopping,
        attack=adversarial_attack,
        attack_ratio=kwargs['adversarial_ratio'],
        attack_p=kwargs['adversarial_p'],
        attack_eps=kwargs['adversarial_eps'],
        attack_eps_growth_epoch=kwargs['adversarial_eps_growth_epoch'],
        attack_eps_growth_start=kwargs['adversarial_eps_growth_start'],
        checkpoint_every=kwargs['checkpoint_every'],
        checkpoint_path=checkpoint_path,
        loaded_checkpoint=loaded_checkpoint,
        choose_best=kwargs['choose_best'])

    save_to = kwargs['save_to']
    pathlib.Path(save_to).parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), save_to)
コード例 #9
0
def distance_dataset(**kwargs):
    """Build and save a dataset of (image, adversarial distance) pairs.

    Samples are taken from ``--from-genuine`` and/or ``--from-adversarial``.
    Each source is restricted with ``dataset_edges=(start, stop)``; see
    ``parsing.parse_dataset`` for the exact edge semantics. The attack pool
    is run against the loaded model on each source. The genuine images and
    the distances returned by ``tests.attack_test`` are collected into an
    ``ad.AdversarialDistanceDataset``, which is saved with
    ``utils.save_zip`` to ``kwargs['save_to']``.

    For ``--from-adversarial`` the adversarial examples are re-labelled with
    the model's own predictions (``utils.create_label_dataset``) before being
    attacked.

    Args:
        **kwargs: parsed command-line options. Keys read directly here:
            ``from_genuine``, ``from_adversarial``, ``log_level``,
            ``deterministic``, ``seed``, ``cpu_threads``, ``domain``,
            ``architecture``, ``state_dict_path``, ``masked_relu``,
            ``device``, ``attack_config_file``, ``attacks``, ``p``,
            ``start``, ``stop``, ``batch_size``,
            ``misclassification_policy`` and ``save_to``. The whole dict is
            also forwarded to ``tests.attack_test``.

    Raises:
        RuntimeError: if neither ``from_genuine`` nor ``from_adversarial``
            is provided, or if no samples were collected at all.
    """
    # Validate the options before doing any expensive work (model loading,
    # attack pool construction).
    if kwargs['from_genuine'] is None and kwargs['from_adversarial'] is None:
        raise RuntimeError(
            'At least one among --from-genuine and --from-adversarial must be provided.'
        )

    parsing.set_log_level(kwargs['log_level'])

    if kwargs['deterministic']:
        if kwargs['seed'] is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')

        utils.enable_determinism()

    logger.debug('Running distance-dataset command with kwargs %s.', kwargs)

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if kwargs['seed'] is not None:
        utils.set_seed(kwargs['seed'])

    model = parsing.parse_model(kwargs['domain'],
                                kwargs['architecture'],
                                kwargs['state_dict_path'],
                                False,
                                kwargs['masked_relu'],
                                False,
                                load_weights=True)
    model.eval()
    model.to(kwargs['device'])

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])

    attack_pool = parsing.parse_attack_pool(kwargs['attacks'],
                                            kwargs['domain'],
                                            kwargs['p'],
                                            'standard',
                                            model,
                                            attack_config,
                                            kwargs['device'],
                                            seed=kwargs['seed'])

    p = kwargs['p']

    images = []
    distances = []

    if kwargs['from_genuine'] is not None:
        genuine_dataset = parsing.parse_dataset(kwargs['domain'],
                                                kwargs['from_genuine'],
                                                dataset_edges=(kwargs['start'],
                                                               kwargs['stop']))
        genuine_loader = torch.utils.data.DataLoader(genuine_dataset,
                                                     kwargs['batch_size'],
                                                     shuffle=False)

        genuine_images, genuine_distances = _attack_and_collect(
            model, attack_pool, genuine_loader, p,
            kwargs['misclassification_policy'], attack_config,
            genuine_dataset.start, genuine_dataset.stop, kwargs)

        images += genuine_images
        distances += genuine_distances

    if kwargs['from_adversarial'] is not None:
        adversarial_dataset = parsing.parse_dataset(
            kwargs['domain'],
            kwargs['from_adversarial'],
            allow_standard=False,
            dataset_edges=(kwargs['start'], kwargs['stop']))

        # Read the edges before the dataset is replaced by its labelled copy.
        adv_start, adv_stop = adversarial_dataset.start, adversarial_dataset.stop
        # Get the labels for the adversarial samples
        adversarial_dataset = utils.create_label_dataset(
            model, adversarial_dataset.adversarials, kwargs['batch_size'])

        adversarial_loader = torch.utils.data.DataLoader(adversarial_dataset,
                                                         kwargs['batch_size'],
                                                         shuffle=False)

        # The labels above are the model's own predictions, so the
        # misclassification policy should have nothing to act on. The user's
        # policy is forwarded for consistency with the genuine branch instead
        # of a hard-coded ``False``.
        adversarial_images, adversarial_distances = _attack_and_collect(
            model, attack_pool, adversarial_loader, p,
            kwargs['misclassification_policy'], attack_config,
            adv_start, adv_stop, kwargs)

        images += adversarial_images
        distances += adversarial_distances

    # torch.stack rejects an empty sequence; fail with an explicit message.
    if not images:
        raise RuntimeError(
            'No samples were collected, so there is nothing to save.')

    images = torch.stack(images)
    distances = torch.stack(distances)

    final_dataset = ad.AdversarialDistanceDataset(images, distances)

    utils.save_zip(final_dataset, kwargs['save_to'])


def _attack_and_collect(model, attack_pool, loader, p, misclassification_policy,
                        attack_config, start, stop, kwargs):
    """Run ``attack_pool`` over ``loader`` and return the genuines and distances.

    Positional arguments are forwarded to ``tests.attack_test`` in the same
    order used by the ``attack`` command: ``..., device, attack_config,
    kwargs, start, stop, None``.

    Returns:
        A ``(genuines, distances)`` pair of lists. They are built from the
        ``genuines`` and ``distances`` attributes of the dataset returned by
        ``tests.attack_test``.
    """
    result_dataset = tests.attack_test(model, attack_pool, loader, p,
                                       misclassification_policy,
                                       kwargs['device'], attack_config, kwargs,
                                       start, stop, None)

    return list(result_dataset.genuines), list(result_dataset.distances)