def attack(**kwargs):
    """Run an attack pool against a model and report the adversarials found.

    All options come from ``kwargs`` (the parsed command-line options). The
    command configures logging, determinism, CPU threads and seeding, loads
    the model and the ``[start, stop)`` slice of the dataset, builds the
    attack pool, runs it through ``tests.attack_test``, prints statistics,
    and optionally saves (``save_to``) and displays (``show``) the results.
    """
    parsing.set_log_level(kwargs['log_level'])

    seed = kwargs['seed']

    if kwargs['deterministic']:
        if seed is None:
            logger.warning('Determinism is enabled, but no seed has been provided.')
        utils.enable_determinism()

    logger.debug('Running attack command with kwargs %s.', kwargs)

    cpu_threads = kwargs['cpu_threads']
    if cpu_threads is not None:
        torch.set_num_threads(cpu_threads)

    if seed is not None:
        utils.set_seed(seed)

    domain = kwargs['domain']
    device = kwargs['device']
    p = kwargs['p']

    model = parsing.parse_model(domain, kwargs['architecture'],
                                kwargs['state_dict_path'], False,
                                kwargs['masked_relu'], False,
                                load_weights=True)
    model.eval()
    model.to(device)

    dataset = parsing.parse_dataset(
        domain, kwargs['dataset'],
        dataset_edges=(kwargs['start'], kwargs['stop']))
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=kwargs['batch_size'], shuffle=False)

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])
    attack_pool = parsing.parse_attack_pool(
        kwargs['attacks'], domain, p, kwargs['attack_type'], model,
        attack_config, device, seed=seed)

    blind_trust = kwargs['blind_trust']
    if blind_trust:
        logger.warning(
            'Blind trust is activated. This means that the success of the attack will NOT be checked.')

    # No defended model is involved in a plain attack (None below).
    result_dataset = tests.attack_test(
        model, attack_pool, dataloader, p,
        kwargs['misclassification_policy'], device, attack_config, kwargs,
        dataset.start, dataset.stop, None, blind_trust=blind_trust)
    result_dataset.print_stats()

    save_to = kwargs['save_to']
    if save_to is not None:
        utils.save_zip(result_dataset, save_to)

    show = kwargs['show']
    if show is not None:
        utils.show_images(result_dataset.genuines,
                          result_dataset.adversarials,
                          limit=show, model=model)
def evasion(**kwargs):
    """Run evasion attacks against a model defended by a Counter-Attack pool.

    A detector pool is built from ``counter_attacks`` (each with its
    substitute state dict) and wrapped, together with the model and
    ``rejection_threshold``, in a ``detectors.NormalisedDetectorModel``. The
    ``evasion_attacks`` pool is then built against that defended model and
    run through ``tests.attack_test``. Statistics are printed and the results
    are optionally saved (``save_to``) and displayed (``show``).

    Raises:
        click.BadArgumentUsage: if the number of substitute state dict paths
            differs from the number of counter attacks.
    """
    parsing.set_log_level(kwargs['log_level'])

    seed = kwargs['seed']
    if kwargs['deterministic']:
        if seed is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')
        utils.enable_determinism()

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if seed is not None:
        utils.set_seed(seed)

    domain, device, p = kwargs['domain'], kwargs['device'], kwargs['p']

    model = parsing.parse_model(domain, kwargs['architecture'],
                                kwargs['state_dict_path'], False,
                                kwargs['masked_relu'], False,
                                load_weights=True)
    model.eval()
    model.to(device)

    dataset = parsing.parse_dataset(
        domain, kwargs['dataset'],
        dataset_edges=(kwargs['start'], kwargs['stop']))
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=kwargs['batch_size'], shuffle=False)

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])

    counter_attack_names = kwargs['counter_attacks']
    substitute_state_dict_paths = kwargs['substitute_state_dict_paths']
    rejection_threshold = kwargs['rejection_threshold']

    # NOTE(review): the message says "positive", but this condition also
    # fires for a threshold of exactly 0 (cross_validation uses `> 0`);
    # confirm which comparison is intended.
    if rejection_threshold >= 0:
        logger.warning(
            'Received a positive rejection threshold. Since Counter-Attack only outputs nonpositive values, '
            'the detector will never reject an example.')

    if len(counter_attack_names) != len(substitute_state_dict_paths):
        raise click.BadArgumentUsage(
            'substitute_state_dict_paths must be as many values as the number of counter attacks.'
        )

    detector = parsing.parse_detector_pool(
        counter_attack_names, domain, p, 'defense', model, attack_config,
        device, use_substitute=True,
        substitute_state_dict_paths=substitute_state_dict_paths)
    defended_model = detectors.NormalisedDetectorModel(
        model, detector, rejection_threshold)

    # TODO (translated from the original Italian): the parameters are wrong.
    evasion_pool = parsing.parse_attack_pool(
        kwargs['evasion_attacks'], domain, p, 'evasion', model,
        attack_config, device, defended_model=defended_model, seed=seed)

    # NOTE(review): after attack_config this passes (start, stop, kwargs),
    # whereas the `attack` command passes (kwargs, start, stop); confirm
    # against the signature of tests.attack_test.
    result_dataset = tests.attack_test(
        model, evasion_pool, dataloader, p,
        kwargs['misclassification_policy'], device, attack_config,
        dataset.start, dataset.stop, kwargs, defended_model)
    result_dataset.print_stats()

    save_to = kwargs['save_to']
    if save_to is not None:
        utils.save_zip(result_dataset, save_to)

    show = kwargs['show']
    if show is not None:
        utils.show_images(result_dataset.genuines,
                          result_dataset.adversarials,
                          limit=show, model=model)
def compare(**kwargs):
    """Run several attacks on the same samples and compare their results.

    Samples come either from the ``[start, stop)`` slice of ``dataset`` or,
    when ``indices_path`` is given, from the JSON list of indices in that
    file (in which case no start/stop is forwarded to the test). Each name
    in ``attacks`` is parsed with the 'standard' configuration and all of
    them are run by ``tests.multiple_attack_test``. Unless ``no_stats`` is
    set, statistics are printed. With ``show``, the pooled best results and
    then each attack's own results are displayed. The comparison dataset is
    optionally saved to ``save_to``.
    """
    parsing.set_log_level(kwargs['log_level'])

    seed = kwargs['seed']
    if kwargs['deterministic']:
        if seed is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')
        utils.enable_determinism()

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if seed is not None:
        utils.set_seed(seed)

    domain = kwargs['domain']
    device = kwargs['device']
    p = kwargs['p']
    show = kwargs['show']

    model = parsing.parse_model(domain, kwargs['architecture'],
                                kwargs['state_dict_path'], False,
                                kwargs['masked_relu'], False,
                                load_weights=True)
    model.eval()

    indices_path = kwargs['indices_path']
    if indices_path is not None:
        with open(indices_path) as indices_file:
            indices_override = json.load(indices_file)
        # An explicit index list replaces the start/stop slice in the test.
        test_start = test_stop = None
    else:
        indices_override = None
        test_start, test_stop = kwargs['start'], kwargs['stop']

    dataset = parsing.parse_dataset(
        domain, kwargs['dataset'],
        dataset_edges=(kwargs['start'], kwargs['stop']),
        indices_override=indices_override)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=kwargs['batch_size'], shuffle=False)

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])

    attack_names = kwargs['attacks']
    parsed_attacks = [
        parsing.parse_attack(name, domain, p, 'standard', model,
                             attack_config, device, seed=seed)
        for name in attack_names
    ]

    result_dataset = tests.multiple_attack_test(
        model, attack_names, parsed_attacks, dataloader, p,
        kwargs['misclassification_policy'], device, attack_config,
        test_start, test_stop, kwargs, indices_override=indices_override)

    if not kwargs['no_stats']:
        result_dataset.print_stats()

    if show is not None:
        print()
        print('Showing top results.')
        pooled_results = result_dataset.simulate_pooling(attack_names)
        utils.show_images(pooled_results.genuines,
                          pooled_results.adversarials,
                          limit=show, model=model)

        for name in attack_names:
            print(f'Showing results for {name}.')
            single_results = result_dataset.simulate_pooling([name])
            utils.show_images(single_results.genuines,
                              single_results.adversarials,
                              limit=show, model=model)

    save_to = kwargs['save_to']
    if save_to is not None:
        utils.save_zip(result_dataset, save_to)
def mip(**kwargs):
    """Run the MIP attack and record wall-clock timings for its phases.

    Only the L2 and L-infinity norms are supported. An optional
    ``pre_adversarial_dataset`` (a zip loaded with ``utils.load_zip``) is
    forwarded to ``tests.mip_test``. If it is an
    ``AttackComparisonDataset``, it is first collapsed into an adversarial
    dataset over all of its attacks. Start/end timestamps of the whole
    command, model retrieval, dataset retrieval and attack creation are
    stored in ``mip_dataset.global_extra_info['times']``.

    Raises:
        NotImplementedError: if ``p`` is neither 2 nor +inf.
    """
    command_start_timestamp = time.time()

    parsing.set_log_level(kwargs['log_level'])

    seed = kwargs['seed']
    if kwargs['deterministic']:
        logger.warning('Determinism is not guaranteed for Gurobi.')
        if seed is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')
        utils.enable_determinism()

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if seed is not None:
        utils.set_seed(seed)

    domain = kwargs['domain']

    torch_model_retrieval_start_timestamp = time.time()
    model = parsing.parse_model(domain, kwargs['architecture'],
                                kwargs['state_dict_path'], False,
                                kwargs['masked_relu'], False,
                                load_weights=True)
    model.eval()
    # Model retrieval ends exactly where dataset retrieval starts.
    torch_model_retrieval_end_timestamp = time.time()
    dataset_retrieval_start_timestamp = torch_model_retrieval_end_timestamp

    dataset = parsing.parse_dataset(
        domain, kwargs['dataset'],
        dataset_edges=(kwargs['start'], kwargs['stop']))
    dataset_retrieval_end_timestamp = time.time()

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=kwargs['batch_size'], shuffle=False)

    pre_adversarial_dataset = None
    pre_adversarial_path = kwargs['pre_adversarial_dataset']
    if pre_adversarial_path is not None:
        pre_adversarial_dataset = utils.load_zip(pre_adversarial_path)
        if isinstance(pre_adversarial_dataset,
                      adversarial_dataset.AttackComparisonDataset):
            # Collapse a comparison into the best result across its attacks.
            pre_adversarial_dataset = pre_adversarial_dataset.to_adversarial_dataset(
                pre_adversarial_dataset.attack_names)

    p = kwargs['p']
    if p == 2:
        metric = 'l2'
    elif np.isposinf(p):
        metric = 'linf'
    else:
        raise NotImplementedError(f'Unsupported metric "l{p}"')

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])
    attack_kwargs = attack_config.get_arguments('mip', domain, metric,
                                                'standard')

    attack_creation_start_timestamp = time.time()
    mip_attack = attacks.MIPAttack(model, p, False, seed=seed, **attack_kwargs)
    attack_creation_end_timestamp = time.time()

    mip_dataset = tests.mip_test(
        model, mip_attack, dataloader, p, kwargs['misclassification_policy'],
        kwargs['device'], attack_config, kwargs,
        start=dataset.start, stop=dataset.stop,
        pre_adversarial_dataset=pre_adversarial_dataset,
        gurobi_log_dir=kwargs['gurobi_log_dir'])
    mip_dataset.print_stats()

    command_end_timestamp = time.time()

    phase_timestamps = (
        ('command', command_start_timestamp, command_end_timestamp),
        ('torch_model_retrieval', torch_model_retrieval_start_timestamp,
         torch_model_retrieval_end_timestamp),
        ('dataset_retrieval', dataset_retrieval_start_timestamp,
         dataset_retrieval_end_timestamp),
        ('attack_creation', attack_creation_start_timestamp,
         attack_creation_end_timestamp),
    )
    recorded_times = mip_dataset.global_extra_info['times']
    for phase_name, phase_start, phase_end in phase_timestamps:
        recorded_times[phase_name] = {
            'start_timestamp': phase_start,
            'end_timestamp': phase_end
        }

    save_to = kwargs['save_to']
    if save_to is not None:
        utils.save_zip(mip_dataset, save_to)

    show = kwargs['show']
    if show is not None:
        utils.show_images(mip_dataset.genuines, mip_dataset.adversarials,
                          limit=show, model=model)
def prune_relu(**kwargs):
    """Prune the ReLUs of a sequential model and save the converted weights.

    The work is delegated to ``pruning.prune_relu``. It receives:
    - the model;
    - a dataloader over the ``[start, stop)`` slice of ``dataset``;
    - an attack pool built with the 'training' configuration;
    - ``adversarial_ratio``, ``epsilon``, ``threshold`` and the device.

    It returns the converted model, the total number of ReLUs and the number
    of replaced ones. The converted ``state_dict`` is written to ``save_to``,
    and parent directories are created as needed.

    Raises:
        click.BadArgumentUsage: if ``adversarial_ratio`` is not in (0, 1],
            or ``threshold`` is in [0, 0.5].
        ValueError: if ``threshold`` is outside [0, 1], or the loaded model
            is not an ``nn.Sequential``.
    """
    parsing.set_log_level(kwargs['log_level'])

    if kwargs['deterministic']:
        if kwargs['seed'] is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')
        utils.enable_determinism()

    logger.debug('Running prune-relu command with kwargs %s.', kwargs)

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if kwargs['seed'] is not None:
        utils.set_seed(kwargs['seed'])

    # Validate the scalar options before loading the model, so that bad
    # input fails fast.
    adversarial_ratio = kwargs['adversarial_ratio']
    if adversarial_ratio <= 0 or adversarial_ratio > 1:
        # click's UsageError subclasses take (message, ctx=None). Passing the
        # option name first would make click treat the message as a Context
        # and fail with an AttributeError instead of a usage error.
        raise click.BadArgumentUsage(
            'adversarial_ratio must be between 0 (exclusive) and 1 (inclusive).'
        )

    threshold = kwargs['threshold']
    if threshold < 0 or threshold > 1:
        raise ValueError('Threshold must be between 0 and 1 (inclusive).')
    if threshold <= 0.5:
        # Together with the check above, the accepted range is (0.5, 1].
        raise click.BadArgumentUsage('threshold must be in (0.5, 1].')

    if kwargs['dataset'] == 'std:test':
        logger.warning(
            'This command is recommended to be used with non-test datasets.')

    device = kwargs['device']
    model = parsing.parse_model(kwargs['domain'], kwargs['architecture'],
                                kwargs['original_state_dict_path'], False,
                                False, False, load_weights=True)
    model.eval()
    model.to(device)

    if not isinstance(model, nn.Sequential):
        raise ValueError('This command only works with sequential networks.')

    dataset = parsing.parse_dataset(
        kwargs['domain'], kwargs['dataset'],
        dataset_edges=(kwargs['start'], kwargs['stop']))
    dataloader = torch.utils.data.DataLoader(dataset, kwargs['batch_size'],
                                             shuffle=False)

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])
    attack_pool = parsing.parse_attack_pool(
        kwargs['attacks'], kwargs['domain'], kwargs['p'], 'training', model,
        attack_config, device, seed=kwargs['seed'])

    converted_model, total_relus, replaced_relus = pruning.prune_relu(
        model, dataloader, attack_pool, adversarial_ratio, kwargs['epsilon'],
        threshold, device)

    print(f'Replaced {replaced_relus} ReLUs out of {total_relus}.')

    save_to = kwargs['save_to']
    pathlib.Path(save_to).parent.mkdir(parents=True, exist_ok=True)
    torch.save(converted_model.state_dict(), save_to)
def tune_mip(**kwargs):
    """Tune Gurobi's parameters on the MIP built for a single genuine sample.

    Steps:
    1. Pick one sample from ``dataset``: ``tuning_index``, or a random one
       when it is -1.
    2. Run the MIP attack on it; this builds the JuMP model.
    3. Run Gurobi's tuner on the underlying Gurobi model and retrieve tuning
       result 0 into it.
    4. Write the model with ``Gurobi.write_model`` to ``save_to``, which must
       end in ``.prm``.

    If ``pre_adversarial_dataset`` is given and has an adversarial for the
    chosen index, that adversarial is passed to the attack as a heuristic
    starting point, after checking that its genuine matches the sample.

    Raises:
        click.BadArgumentUsage: if ``save_to`` does not end in ``.prm``.
        ValueError: if the pre-adversarial dataset uses a different
            misclassification policy, or has no successful adversarial to
            pick a random tuning index from.
        RuntimeError: if the pre-adversarial's genuine differs from the
            chosen sample.
    """
    parsing.set_log_level(kwargs['log_level'])

    save_to = kwargs['save_to']
    if not save_to.endswith('.prm'):
        raise click.BadArgumentUsage(
            'save_to must have a .prm file extension.')

    seed = kwargs['seed']
    if kwargs['deterministic']:
        if seed is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')
        utils.enable_determinism()

    # Seed like the other commands, so that the random tuning_index choice
    # below can be reproduced with --seed.
    # NOTE(review): assumes utils.set_seed also seeds NumPy's global RNG.
    if seed is not None:
        utils.set_seed(seed)

    model = parsing.parse_model(kwargs['domain'], kwargs['architecture'],
                                kwargs['state_dict_path'], False,
                                kwargs['masked_relu'], False,
                                load_weights=True)
    model.eval()

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])
    # TODO: model.cpu()?
    attack = parsing.parse_attack('mip', kwargs['domain'], kwargs['p'],
                                  'standard', model, attack_config, 'cpu',
                                  seed=seed)

    misclassification_policy = kwargs['misclassification_policy']

    if kwargs['pre_adversarial_dataset'] is None:
        pre_adversarial_dataset = None
    else:
        pre_adversarial_dataset = utils.load_zip(
            kwargs['pre_adversarial_dataset'])
        if pre_adversarial_dataset.misclassification_policy != misclassification_policy:
            raise ValueError(
                'The misclassification policy of the pre-adversarial dataset does '
                'not match the given policy. This can produce incorrect starting points.'
            )

    dataset = parsing.parse_dataset(kwargs['domain'], kwargs['dataset'])

    # The "remove" policy drops samples and therefore shifts indices. Apply it
    # to the genuine dataset too, so that indices line up with a
    # pre-adversarial dataset built with the same policy.
    if misclassification_policy == 'remove':
        batch_size = kwargs['batch_size']
        all_images = []
        all_true_labels = []
        for start in range(0, len(dataset), batch_size):
            stop = min(start + batch_size, len(dataset))
            # Fetch each sample once rather than once per field.
            batch = [dataset[i] for i in range(start, stop)]
            images = torch.stack([sample[0] for sample in batch])
            true_labels = torch.stack(
                [torch.tensor(sample[1]) for sample in batch])
            images, true_labels, _ = utils.apply_misclassification_policy(
                model, images, true_labels, 'remove')
            all_images += list(images)
            all_true_labels += list(true_labels)
        dataset = list(zip(all_images, all_true_labels))

    pre_adversarial = None
    pre_image = None

    if pre_adversarial_dataset is None:
        if kwargs['tuning_index'] == -1:
            tuning_index = np.random.randint(len(dataset))
        else:
            tuning_index = kwargs['tuning_index']
    else:
        successful_indices = [
            i for i in range(len(pre_adversarial_dataset))
            if pre_adversarial_dataset.adversarials[i] is not None
        ]

        if kwargs['tuning_index'] == -1:
            if not successful_indices:
                raise ValueError(
                    'The pre-adversarial dataset contains no successful '
                    'adversarials to choose a tuning_index from.')
            tuning_index = np.random.choice(successful_indices)
        else:
            tuning_index = kwargs['tuning_index']

        if tuning_index in successful_indices:
            pre_adversarial = pre_adversarial_dataset.adversarials[tuning_index]
            pre_adversarial = pre_adversarial.detach().cpu().numpy()
            pre_image = pre_adversarial_dataset.genuines[tuning_index]
            pre_image = pre_image.detach().cpu().numpy()
        else:
            # There is no adversarial for this index: run without a starting
            # point (pre_adversarial stays None) instead of calling .detach()
            # on None.
            logger.warning(
                'The chosen tuning_index does not have a matching '
                'pre-adversarial. Ignoring pre-adversarial optimizations.')

    image, label = dataset[tuning_index]
    image = image.detach().cpu().numpy()
    # The "remove" branch above yields tensor labels; parse_dataset may yield
    # plain ints, which have no .detach(). int() handles both.
    label = int(label)

    if pre_image is not None:
        max_difference = np.max(np.abs(image - pre_image))
        if max_difference > 1e-6:
            print(max_difference)
            raise RuntimeError(
                'The pre-adversarial refers to a different genuine. '
                'This can slow down MIP at best and make it fail at worst. '
                'Are you sure that you\'re using the correct pre-adversarial dataset?'
            )

    # Running the attack implicitly builds the MIP model.
    # TODO (translated from the original Italian): having a retry system
    # here makes no sense.
    _, adversarial_result = attack.mip_attack(
        image, label, heuristic_starting_point=pre_adversarial)
    jump_model = adversarial_result['Model']

    # Get the Gurobi model underlying the JuMP model.
    from julia import JuMP
    from julia import Gurobi
    from julia import Main

    gurobi_model = JuMP.internalmodel(jump_model).inner

    Gurobi.tune_model(gurobi_model)
    # Retrieve tuning result 0 into the model through Julia's Main module.
    Main.model_pointer = gurobi_model
    Main.eval('Gurobi.get_tune_result!(model_pointer, 0)')

    # Save the model (the .prm extension was enforced above).
    Gurobi.write_model(gurobi_model, save_to)
def cross_validation(**kwargs):
    """Leave-one-out evaluation of Counter-Attack against evasion attacks.

    Each attack in ``attacks`` is used in turn as the evasion attack against
    a model defended by a detector pool. The pool is built from all the
    *other* attacks, with their substitute state dicts and the held-out
    attack's rejection threshold. All tests are run by
    ``tests.multiple_evasion_test``. The combined dataset is optionally saved
    (``save_to``), then per-test statistics are printed and, with ``show``,
    images are displayed.

    Raises:
        click.BadArgumentUsage: if fewer than two attacks are given, or the
            number of rejection thresholds / substitute state dict paths does
            not match the number of attacks.
    """
    parsing.set_log_level(kwargs['log_level'])

    seed = kwargs['seed']
    if kwargs['deterministic']:
        if seed is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')
        utils.enable_determinism()

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if seed is not None:
        utils.set_seed(seed)

    domain, device, p = kwargs['domain'], kwargs['device'], kwargs['p']
    show = kwargs['show']

    model = parsing.parse_model(domain, kwargs['architecture'],
                                kwargs['state_dict_path'], False,
                                kwargs['masked_relu'], False,
                                load_weights=True)
    model.eval()
    model.to(device)

    dataset = parsing.parse_dataset(
        domain, kwargs['dataset'],
        dataset_edges=(kwargs['start'], kwargs['stop']))
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=kwargs['batch_size'], shuffle=False)

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])

    attack_names = kwargs['attacks']
    thresholds = kwargs['rejection_thresholds']
    substitute_paths = kwargs['substitute_state_dict_paths']
    attack_count = len(attack_names)

    if attack_count < 2:
        raise click.BadArgumentUsage('attacks must be at least two.')

    # A single threshold is shared by every test.
    if len(thresholds) == 1:
        thresholds = [thresholds[0]] * attack_count

    if len(thresholds) != attack_count:
        raise click.BadArgumentUsage(
            'rejection_thresholds must be either one value or as many values as the number of attacks.'
        )

    if len(substitute_paths) != attack_count:
        raise click.BadArgumentUsage(
            'substitute_state_dict_paths must be as many values as the number of attacks.'
        )

    if any(threshold > 0 for threshold in thresholds):
        logger.warning(
            'Received a positive rejection threshold. Since Counter-Attack only outputs nonpositive values, '
            'the detector will never reject an example.')

    test_names, evasion_attacks, defended_models = [], [], []

    for held_out, (evasion_attack_name, rejection_threshold) in enumerate(
            zip(attack_names, thresholds)):
        # The held-out attack is the evasion attack; every other attack (with
        # its substitute model) becomes part of the detector pool.
        counter_attack_names = [
            name for k, name in enumerate(attack_names) if k != held_out
        ]
        counter_substitute_paths = [
            path for k, path in enumerate(substitute_paths) if k != held_out
        ]

        detector = parsing.parse_detector_pool(
            counter_attack_names, domain, p, 'standard', model,
            attack_config, device, use_substitute=True,
            substitute_state_dict_paths=counter_substitute_paths)
        defended_model = detectors.NormalisedDetectorModel(
            model, detector, rejection_threshold)
        evasion_attack = parsing.parse_attack(
            evasion_attack_name, domain, p, 'evasion', model, attack_config,
            device, defended_model=defended_model, seed=seed)

        test_names.append(f'{evasion_attack_name} vs {counter_attack_names}')
        evasion_attacks.append(evasion_attack)
        defended_models.append(defended_model)

    logger.info('Tests:\n{}'.format('\n'.join(test_names)))

    evasion_dataset = tests.multiple_evasion_test(
        model, test_names, evasion_attacks, defended_models, dataloader, p,
        kwargs['misclassification_policy'], device, attack_config,
        dataset.start, dataset.stop, kwargs)

    save_to = kwargs['save_to']
    if save_to is not None:
        utils.save_zip(evasion_dataset, save_to)

    for test_name in test_names:
        print(f'Test "{test_name}":')
        test_dataset = evasion_dataset.to_adversarial_dataset(test_name)
        test_dataset.print_stats()

        if show is not None:
            utils.show_images(test_dataset.genuines,
                              test_dataset.adversarials,
                              limit=show, model=model)
def train_classifier(**kwargs):
    """Train a classifier from scratch and save its ``state_dict``.

    Supported options:
    - data augmentation: ``flip``, ``rotation``, ``translation``;
    - a validation set: ``validation_split`` or ``validation_dataset``
      (mutually exclusive);
    - early stopping;
    - adversarial training: ``adversarial_*`` options;
    - RS regularization: ``rs_*`` options;
    - periodic checkpoints (to ``save_to`` + '-checkpoint') and resuming
      from ``load_checkpoint``.

    The trained weights are written to ``save_to``, and parent directories
    are created as needed.

    Raises:
        click.BadOptionUsage: on mutually exclusive, missing or out-of-range
            options.
    """
    parsing.set_log_level(kwargs['log_level'])
    logger.debug('Running train-classifier command with kwargs %s', kwargs)

    # Check the mutually exclusive validation options before any expensive
    # setup, so that bad input fails fast.
    if kwargs['validation_dataset'] is not None and kwargs[
            'validation_split'] != 0:
        raise click.BadOptionUsage(
            '--validation-split',
            '--validation-dataset and --validation-split are mutually exclusive.'
        )

    if kwargs['deterministic']:
        if kwargs['seed'] is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')
        utils.enable_determinism()

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if kwargs['seed'] is not None:
        utils.set_seed(kwargs['seed'])

    # The model is created from scratch: no state dict, no weights loaded.
    model = parsing.parse_model(kwargs['domain'], kwargs['architecture'],
                                None, False, kwargs['masked_relu'], True,
                                load_weights=False)
    model.train()

    # Data augmentation
    extra_transforms = []
    if kwargs['flip']:
        extra_transforms.append(torchvision.transforms.RandomHorizontalFlip())

    if kwargs['rotation'] != 0 or kwargs['translation'] != 0:
        if kwargs['translation'] < 0 or kwargs['translation'] > 1:
            logger.warning('The suggested range for --translation is [0, 1].')
        if kwargs['rotation'] < 0 or kwargs['rotation'] > 180:
            logger.warning('The suggested range for --rotation is [0, 180].')

        # RandomAffine takes a (horizontal, vertical) pair of maximum
        # translation fractions, or None for no translation.
        translation = (
            kwargs['translation'],
            kwargs['translation']) if kwargs['translation'] != 0 else None
        extra_transforms.append(
            torchvision.transforms.RandomAffine(kwargs['rotation'],
                                                translation))

    train_dataset = parsing.parse_dataset(kwargs['domain'], kwargs['dataset'],
                                          extra_transforms=extra_transforms)

    # Validation
    val_dataset = None
    if kwargs['validation_split'] != 0:
        logger.debug('Performing a validation split.')
        train_dataset, val_dataset = training.split_dataset(
            train_dataset, kwargs['validation_split'], shuffle=True)
    elif kwargs['validation_dataset'] is not None:
        logger.debug('Loading an existing validation dataset.')
        val_dataset = parsing.parse_dataset(kwargs['domain'],
                                            kwargs['validation_dataset'],
                                            allow_standard=True)

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   kwargs['batch_size'],
                                                   shuffle=kwargs['shuffle'])

    if val_dataset is None:
        val_dataloader = None
    else:
        val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                     kwargs['batch_size'],
                                                     shuffle=False)

    # Early stopping
    early_stopping = None
    if kwargs['early_stopping'] > 0:
        if kwargs['choose_best'] and kwargs['early_stopping_delta'] != 0:
            logger.warning(
                'Received --choose-best and --early-stopping with delta != 0. '
                'Remember that with delta != 0, --choose-best and --early-stopping '
                'track differently the best loss and state_dict.')

        logger.debug('Adding early stopping.')
        early_stopping = training.EarlyStopping(
            kwargs['early_stopping'], delta=kwargs['early_stopping_delta'])

    # Adversarial training. `not` (rather than `== []`) treats an empty list,
    # an empty tuple (what click yields for `multiple=True` options) and None
    # all as "disabled".
    if not kwargs['adversarial_training']:
        adversarial_attack = None

        if kwargs['adversarial_ratio'] is not None:
            logger.warning(
                'Received --adversarial-ratio without --adversarial-training.')
        if kwargs['adversarial_p'] is not None:
            logger.warning(
                'Received --adversarial-p without --adversarial-training.')
        if kwargs['adversarial_eps'] is not None:
            logger.warning(
                'Received --adversarial-eps without --adversarial-training.')
        if kwargs['adversarial_eps_growth_epoch'] != 0:
            logger.warning(
                'Received --adversarial-eps-growth-epoch without --adversarial-training.'
            )
        if kwargs['adversarial_eps_growth_start'] is not None:
            logger.warning(
                'Received --adversarial-eps-growth-start without --adversarial-training.'
            )
    else:
        logger.debug('Enabling adversarial training.')

        if kwargs['adversarial_ratio'] is None:
            raise click.BadOptionUsage(
                '--adversarial-ratio',
                'Please specify the ratio for adversarial training with --adversarial-ratio.'
            )

        if kwargs['adversarial_ratio'] <= 0 or kwargs['adversarial_ratio'] > 1:
            raise click.BadOptionUsage(
                '--adversarial-ratio',
                '--adversarial-ratio must be between 0 (exclusive) and 1 (inclusive).'
            )

        if kwargs['adversarial_p'] is None:
            raise click.BadOptionUsage(
                '--adversarial-p',
                'Please specify the Lp norm for adversarial training with --adversarial-p.'
            )

        if kwargs['adversarial_eps'] is None:
            raise click.BadOptionUsage(
                '--adversarial-eps',
                'Please specify the maximum perturbation norm for adversarial training with --adversarial-eps (inf is also allowed).'
            )

        if kwargs['adversarial_eps_growth_epoch'] > 0:
            if kwargs['adversarial_eps_growth_start'] is None:
                raise click.BadOptionUsage(
                    '--adversarial-eps-growth-start',
                    'Please specify the initial value for adversarial epsilon growth with --adversarial-eps-growth-start '
                    '(0 is also allowed).')

            if kwargs['early_stopping'] > 0:
                logger.warning(
                    'Received --adversarial-eps-growth-epoch and --early-stopping together.'
                )
        elif kwargs['adversarial_eps_growth_start'] is not None:
            logger.warning(
                'Received --adversarial-eps-growth-start without --adversarial-eps-growth-epoch.'
            )

        attack_config = utils.read_attack_config_file(
            kwargs['attack_config_file'])

        adversarial_attack = parsing.parse_attack_pool(
            kwargs['adversarial_training'], kwargs['domain'],
            kwargs['adversarial_p'], 'training', model, attack_config,
            kwargs['device'], seed=kwargs['seed'])

    # RS loss
    if kwargs['rs_regularization'] == 0:
        if kwargs['rs_eps'] is not None:
            logger.warning('Received --rs-eps without --rs-regularization.')
        if kwargs['rs_start_epoch'] != 1:
            logger.warning(
                'Received --rs-start-epoch without --rs-regularization.')
    else:
        if kwargs['rs_eps'] is None:
            raise click.BadOptionUsage(
                '--rs-eps',
                'Please specify the maximum perturbation for RS loss with --rs-eps.'
            )

        if kwargs['rs_start_epoch'] > kwargs['epochs']:
            logger.warning(
                '--rs-start-epoch is higher than the number of epochs. This means that RS loss will never be activated.'
            )

        if kwargs['rs_start_epoch'] > 1 and kwargs['early_stopping'] > 0:
            logger.warning(
                'Received --rs-start-epoch and --early-stopping together.')

    # Use Mean Cross Entropy, consistent with Xiao and Madry's ReLU training technique
    loss = torch.nn.CrossEntropyLoss(reduction='mean')
    optimiser = parsing.parse_optimiser(kwargs['optimiser'],
                                        model.parameters(), kwargs)

    if kwargs['checkpoint_every'] is None:
        checkpoint_path = None
    else:
        checkpoint_path = kwargs['save_to'] + '-checkpoint'

    if kwargs['load_checkpoint'] is None:
        loaded_checkpoint = None
    else:
        loaded_checkpoint = utils.torch_load(kwargs['load_checkpoint'])

    training.train(
        model,
        train_dataloader,
        optimiser,
        loss,
        kwargs['epochs'],
        kwargs['device'],
        val_loader=val_dataloader,
        l1_regularization=kwargs['l1_regularization'],
        rs_regularization=kwargs['rs_regularization'],
        rs_eps=kwargs['rs_eps'],
        rs_minibatch_size=kwargs['rs_minibatch'],
        rs_start_epoch=kwargs['rs_start_epoch'],
        early_stopping=early_stopping,
        attack=adversarial_attack,
        attack_ratio=kwargs['adversarial_ratio'],
        attack_p=kwargs['adversarial_p'],
        attack_eps=kwargs['adversarial_eps'],
        attack_eps_growth_epoch=kwargs['adversarial_eps_growth_epoch'],
        attack_eps_growth_start=kwargs['adversarial_eps_growth_start'],
        checkpoint_every=kwargs['checkpoint_every'],
        checkpoint_path=checkpoint_path,
        loaded_checkpoint=loaded_checkpoint,
        choose_best=kwargs['choose_best'])

    save_to = kwargs['save_to']
    pathlib.Path(save_to).parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), save_to)
def distance_dataset(**kwargs):
    """Build and save a dataset of (image, distance) pairs from attack results.

    The attack pool (built with the 'standard' configuration) is run on:
    - the genuine samples of ``from_genuine``, and/or
    - the adversarials stored in ``from_adversarial``, after passing them
      through ``utils.create_label_dataset``.

    The ``genuines`` and ``distances`` of each result are concatenated,
    stacked into an ``ad.AdversarialDistanceDataset`` and saved to
    ``save_to``.

    Raises:
        RuntimeError: if neither ``from_genuine`` nor ``from_adversarial`` is
            given, or if no sample was collected.
    """
    parsing.set_log_level(kwargs['log_level'])

    from_genuine = kwargs['from_genuine']
    from_adversarial = kwargs['from_adversarial']

    # Validate the inputs before loading the model and building the attacks.
    if from_genuine is None and from_adversarial is None:
        raise RuntimeError(
            'At least one among --from-genuine and --from-adversarial must be provided.'
        )

    if kwargs['deterministic']:
        if kwargs['seed'] is None:
            logger.warning(
                'Determinism is enabled, but no seed has been provided.')
        utils.enable_determinism()

    if kwargs['cpu_threads'] is not None:
        torch.set_num_threads(kwargs['cpu_threads'])

    if kwargs['seed'] is not None:
        utils.set_seed(kwargs['seed'])

    domain = kwargs['domain']
    device = kwargs['device']
    batch_size = kwargs['batch_size']
    p = kwargs['p']

    model = parsing.parse_model(domain, kwargs['architecture'],
                                kwargs['state_dict_path'], False,
                                kwargs['masked_relu'], False,
                                load_weights=True)
    model.eval()
    model.to(device)

    attack_config = utils.read_attack_config_file(kwargs['attack_config_file'])
    attack_pool = parsing.parse_attack_pool(kwargs['attacks'], domain, p,
                                            'standard', model, attack_config,
                                            device, seed=kwargs['seed'])

    dataset_edges = (kwargs['start'], kwargs['stop'])
    images = []
    distances = []

    if from_genuine is not None:
        genuine_dataset = parsing.parse_dataset(domain, from_genuine,
                                                dataset_edges=dataset_edges)
        genuine_loader = torch.utils.data.DataLoader(genuine_dataset,
                                                     batch_size,
                                                     shuffle=False)
        # TODO (translated from the original Italian): all the parameters
        # are wrong.
        # NOTE(review): after attack_config this passes (start, stop, kwargs),
        # whereas the `attack` command passes (kwargs, start, stop); confirm
        # against the signature of tests.attack_test.
        genuine_result_dataset = tests.attack_test(
            model, attack_pool, genuine_loader, p,
            kwargs['misclassification_policy'], device, attack_config,
            genuine_dataset.start, genuine_dataset.stop, kwargs, None)

        images += list(genuine_result_dataset.genuines)
        distances += list(genuine_result_dataset.distances)

    if from_adversarial is not None:
        # Named so as not to shadow the `adversarial_dataset` module.
        stored_dataset = parsing.parse_dataset(domain, from_adversarial,
                                               allow_standard=False,
                                               dataset_edges=dataset_edges)
        adv_start, adv_stop = stored_dataset.start, stored_dataset.stop

        # Get the labels for the adversarial samples.
        labelled_dataset = utils.create_label_dataset(
            model, stored_dataset.adversarials, batch_size)

        adversarial_loader = torch.utils.data.DataLoader(labelled_dataset,
                                                         batch_size,
                                                         shuffle=False)
        # TODO (translated from the original Italian): the parameters are
        # wrong.
        # NOTE(review): the misclassification policy is passed as False here
        # and the (start, stop, kwargs) order differs from `attack`; confirm
        # both against tests.attack_test.
        adversarial_result_dataset = tests.attack_test(
            model, attack_pool, adversarial_loader, p, False, device,
            attack_config, adv_start, adv_stop, kwargs, None)

        images += list(adversarial_result_dataset.genuines)
        distances += list(adversarial_result_dataset.distances)

    # torch.stack rejects an empty list with an opaque message.
    if not images:
        raise RuntimeError(
            'No samples were collected, so no distance dataset can be built.')

    images = torch.stack(images)
    distances = torch.stack(distances)

    final_dataset = ad.AdversarialDistanceDataset(images, distances)
    utils.save_zip(final_dataset, kwargs['save_to'])