"""Adversarial-training evaluation script.

Parses command-line arguments naming one or more attacks and a
dataset/model checkpoint, then loads the dataset and model.

NOTE(review): the original chunk was collapsed onto a single physical
line (invalid Python); this body restores conventional formatting
without changing any statement, name, or runtime string.  The
evaluation logic that consumes ``args.attacks`` / ``args.output`` is
not visible in this chunk — presumably it follows below.
"""

from typing import Dict, List

import torch
import csv
import argparse

from perceptual_advex.utilities import add_dataset_model_arguments, \
    get_dataset_model
from perceptual_advex.attacks import *

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Adversarial training evaluation')

    # Dataset/model/checkpoint options are added by the shared helper so
    # all scripts in the project accept a consistent CLI.
    add_dataset_model_arguments(parser, include_checkpoint=True)

    parser.add_argument('attacks', metavar='attack', type=str, nargs='+',
                        help='attack names')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='number of examples/minibatch')
    parser.add_argument('--parallel', type=int, default=1,
                        help='number of GPUs to train on')
    parser.add_argument('--num_batches', type=int, required=False,
                        help='number of batches (default entire dataset)')
    parser.add_argument('--per_example', action='store_true', default=False,
                        help='output per-example accuracy')
    parser.add_argument('--output', type=str, help='output CSV')

    args = parser.parse_args()

    # Load the dataset and model named by the parsed arguments; both are
    # module-level names that the (unseen) remainder of the script uses.
    dataset, model = get_dataset_model(args)
"""Training/validation script header.

Imports the project's attack and model utilities, defines the
validation-iteration constant, and declares the command-line interface.

NOTE(review): the original chunk was collapsed onto a single physical
line (invalid Python); this body restores conventional formatting.  The
chunk appears truncated — arguments are declared but ``parse_args`` is
never called in view.
"""

# Fix: argparse was used below but never imported in this script's
# visible import block (it would only resolve by accident through one of
# the wildcard imports).  Import it explicitly.
import argparse

from perceptual_advex import evaluation
from perceptual_advex.utilities import add_dataset_model_arguments, \
    get_dataset_model, calculate_accuracy, get_vae_model
from perceptual_advex.attacks import *
from perceptual_advex.ci_attacks2 import *
from perceptual_advex.cd_attacks2 import *
from perceptual_advex.models import FeatureModel
from perceptual_advex.hidden_attacks import *
from perceptual_advex.distances import L2Distance, LinfDistance

# Iteration count used during validation; its exact semantics are defined
# by the unseen remainder of this script — confirm against that code.
VAL_ITERS = 100

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Shared dataset/model CLI options, consistent with the project's
    # other scripts.
    add_dataset_model_arguments(parser)

    parser.add_argument('--num_epochs', type=int, required=False,
                        help='number of epochs trained')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='number of examples/minibatch')
    parser.add_argument('--val_batches', type=int, default=10,
                        help='number of batches to validate on')
    parser.add_argument('--log_dir', type=str, default='data/logs')
    parser.add_argument('--parallel', type=int, default=1,
                        help='number of GPUs to train on')
    parser.add_argument('--only_attack_correct', action='store_true',
                        default=False, help='only attack examples that '
                        'are classified correctly')