def __main__():
    """Evaluate a trained model on clean, FGSM-attacked and PGD-attacked data.

    Command-line arguments:
        arch: key into ``model_structures`` selecting the model object and
            the test loader it is evaluated on.
        state_dict: path to a file read with ``torch.load``; the result is
            passed to ``model.load_state_dict``.
        --device: device string (default ``"cuda"``) used as the
            ``map_location`` for loading and forwarded to the
            ``classification_report*`` helpers.

    Prints a header line and one report per evaluation to stdout.
    """
    # Map each CLI architecture name to [model instance, test loader].
    model_structures = {
        "mnist_cnn": [mnist_cnn, mnist_testloader],
        "mnist": [mnist_resnet18, mnist_testloader],
        "cifar10_cnn": [cifar10_cnn, cifar10_testloader],
        "cifar10_wide_resnet34_10": [
            cifar10_wide_resnet34_10,
            cifar10_testloader,
        ],
    }

    parser = argparse.ArgumentParser(
        description="Evaluate models against attacked perturbations")
    # Restrict `arch` to the known keys so a typo produces an argparse usage
    # error listing the valid names instead of a bare KeyError further down.
    parser.add_argument("arch", type=str, choices=list(model_structures))
    parser.add_argument("state_dict", type=str)
    parser.add_argument("--device", default="cuda", type=str)
    args = parser.parse_args()

    arch = args.arch
    state_path = args.state_dict
    device = args.device

    model, testloader = model_structures[arch]
    # map_location remaps tensors that were saved on a different device
    # (e.g. a CUDA checkpoint evaluated with --device cpu) instead of failing.
    # NOTE(review): torch.load unpickles the file; only pass trusted
    # checkpoints (consider weights_only=True on torch versions supporting it).
    state = torch.load(state_path, map_location=device)
    model.load_state_dict(state)

    print(f"Testing {arch} with {state_path} on {device}\n")

    print(f"Unattacked {arch}")
    print(classification_report(model, testloader, device=device))

    print(f"FGSM attacked {arch}")
    print(classification_report_fgsm(model, testloader, device=device))

    print(f"PGD attacked {arch}")
    print(classification_report_pgd(model, testloader, device=device))
# Initialise mnist_resnet18 from the weights in mnist_resnet18_state before
# adversarial training.
mnist_resnet18.load_state_dict(mnist_resnet18_state)

# Each entry maps a display name to [model, train loader, test loader].
models = {
    "MNIST ResNet18": [mnist_resnet18, mnist_trainloader, mnist_testloader],
}

# %%
for model_name, (model, trainloader, testloader) in models.items():
    logging.info(f"Training {model_name}")
    new_model = fgsm_training(model, trainloader, device="cuda", log=log,
                              n_epoches=40, random=True)
    # Save and evaluate the model *returned* by fgsm_training. If it returns
    # a trained copy, `model` would still hold the pre-training weights;
    # if it trains in place, both names refer to the same object.
    torch.save(
        new_model.state_dict(),
        os.path.join(SCRIPT_PATH, f"FGSM {model_name}.model"),
    )

    logging.info(f"Unattacked {model_name}")
    logging.info(classification_report(new_model, testloader, device="cuda"))

    logging.info(f"FGSM attacked {model_name}")
    logging.info(
        classification_report_fgsm(new_model, testloader, device="cuda"))

    logging.info(f"PGD attacked {model_name}")
    logging.info(
        classification_report_pgd(new_model, testloader, device="cuda"))

# %%
model, trainloader, device=DEVICE, log=log, n_clusters=n_clusters, cluster_with=cluster_with, **global_param, ) torch.save( new_model.state_dict(), os.path.join( SCRIPT_PATH, f"Cluster {model_name} {cluster_with} {n_clusters}.model", ), ) logging.info(f"Unattacked {model_name}") logging.info( classification_report(new_model, testloader, device=DEVICE)) logging.info(f"FGSM attacked {model_name}") logging.info( classification_report_fgsm(new_model, testloader, device=DEVICE)) logging.info(f"PGD attacked {model_name}") logging.info( classification_report_pgd(new_model, testloader, device=DEVICE))