def main(args):
    """Evaluate a network's clean accuracy on the train and test splits.

    The network is first built with ``get_model``; when ``args.model_path`` is
    given it is replaced by the network loaded from that checkpoint. Train and
    test accuracies are printed and appended to ``clean_performance.csv``
    under ``args.output``.

    input:
        args: Argparse object that contains all the parsed values
    return:
        void
    """
    print(now(), "test_model.py main() running.")

    test_log = "clean_test_log.txt"
    to_log_file(args, args.output, test_log)

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ####################################################
    #               Dataset
    dataset_key = args.dataset.lower()
    # Each TinyImageNet variant loads a different subset of the classes.
    tiny_class_subsets = {
        "tinyimagenet_first": "firsthalf",
        "tinyimagenet_last": "lasthalf",
        "tinyimagenet_all": "all",
    }

    if dataset_key in ("cifar10", "cifar100"):
        if dataset_key == "cifar10":
            cifar_cls = torchvision.datasets.CIFAR10
        else:
            cifar_cls = torchvision.datasets.CIFAR100
        transform_train = get_transform(args.normalize, args.train_augment)
        transform_test = get_transform(args.normalize, False)
        trainset = cifar_cls(root="./data", train=True, download=True,
                             transform=transform_train)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=128)
        testset = cifar_cls(root="./data", train=False, download=True,
                            transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                                 shuffle=False)
    elif dataset_key in tiny_class_subsets:
        class_subset = tiny_class_subsets[dataset_key]
        transform_train = get_transform(args.normalize, args.train_augment,
                                        dataset=args.dataset)
        transform_test = get_transform(args.normalize, False,
                                       dataset=args.dataset)
        trainset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="train",
            transform=transform_train,
            classes=class_subset,
        )
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                                  num_workers=1, shuffle=True)
        testset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="val",
            transform=transform_test,
            classes=class_subset,
        )
        testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                                 num_workers=1, shuffle=False)
    else:
        print("Dataset not yet implemented. Exiting from test_model.py.")
        sys.exit()
    ####################################################

    ####################################################
    #           Network and Optimizer
    net = get_model(args.model, args.dataset)

    # Without a checkpoint the freshly built (untrained) network is evaluated.
    if args.model_path is None:
        print("No model path provided, continuing test with untrained network.")
    else:
        net = load_model_from_checkpoint(args.model, args.model_path,
                                         args.dataset)
    net = net.to(device)
    ####################################################

    ####################################################
    #        Test Model
    training_acc = test(net, trainloader, device)
    natural_acc = test(net, testloader, device)
    print(now(), " Training accuracy: ", training_acc)
    print(now(), " Natural accuracy: ", natural_acc)

    results = [
        ("model path", args.model_path),
        ("model", args.model),
        ("normalize", args.normalize),
        ("augment", args.train_augment),
        ("training_acc", training_acc),
        ("natural_acc", natural_acc),
    ]
    stats = OrderedDict(results)
    to_results_table(stats, args.output, "clean_performance.csv")
    ####################################################

    return
def main():
    """Train a network on an ImageFolder dataset and plot its learning curves.

    Parses its own command-line arguments, trains for ``--epochs`` epochs
    (optionally resuming from ``--model_path``), validates every
    ``--val_period`` epochs, optionally saves checkpoints under
    ``--checkpoint`` (every 10 epochs and at the end), and writes
    accuracy/loss plots into ``./plots/``.

    return:
        void

    raises:
        ValueError: if ``--optimizer`` is neither ``SGD`` nor ``adam``.
    """
    print("\n_________________________________________________\n")
    print(now(), "train_model.py main() running.")
    parser = argparse.ArgumentParser(description='Poisoning Benchmark')
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--lr_schedule', nargs='+', default=[100, 150], type=int,
                        help='how often to decrease lr')
    parser.add_argument('--lr_factor', default=0.1, type=float,
                        help='factor by which to decrease lr')
    parser.add_argument('--epochs', default=200, type=int,
                        help='number of epochs for training')
    parser.add_argument('--optimizer', default='SGD', type=str, help='optimizer')
    parser.add_argument('--model', default='ResNet18', type=str,
                        help='model for training')
    parser.add_argument('--dataset', default='imagenet', type=str, help='dataset')
    parser.add_argument('--val_period', default=1, type=int,
                        help='print every __ epoch')
    parser.add_argument('--output', default='output_default', type=str,
                        help='output subdirectory')
    parser.add_argument('--checkpoint', default='check_default', type=str,
                        help='where to save the network')
    parser.add_argument('--model_path', default='', type=str,
                        help='where is the model saved?')
    parser.add_argument('--save_net', action='store_true', help='save net?')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for seeding random processes.')
    parser.add_argument('--normalize', dest='normalize', action='store_true')
    parser.add_argument('--batch_size', default=128, type=int,
                        help='Batch size for training')
    parser.add_argument('--no-normalize', dest='normalize', action='store_false')
    parser.set_defaults(normalize=True)
    parser.add_argument('--train_augment', dest='train_augment',
                        action='store_true')
    parser.add_argument('--no-train_augment', dest='train_augment',
                        action='store_false')
    parser.set_defaults(train_augment=False)
    parser.add_argument('--test_augment', dest='test_augment',
                        action='store_true')
    parser.add_argument('--no-test_augment', dest='test_augment',
                        action='store_false')
    parser.set_defaults(test_augment=False)
    # ImageFolder roots; the defaults are the previously hard-coded locations.
    # (Names chosen so they do not share an abbreviation prefix with the
    # existing options.)
    parser.add_argument(
        '--imagefolder_train', type=str,
        default='/cmlscratch/arjgpt27/projects/ENPM673/DL/dataset/train',
        help='ImageFolder root of the training split')
    parser.add_argument(
        '--imagefolder_val', type=str,
        default='/cmlscratch/arjgpt27/projects/ENPM673/DL/dataset/val',
        help='ImageFolder root of the validation split')
    args = parser.parse_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    train_log = "train_log_{}.txt".format(args.model)
    to_log_file(args, args.output, train_log)

    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    ####################################################
    #               Dataset
    transform_train = get_transform(args.normalize, args.train_augment,
                                    dataset="imagenet")
    transform_test = get_transform(args.normalize, args.test_augment,
                                   dataset="imagenet")
    trainset = torchvision.datasets.ImageFolder(args.imagefolder_train,
                                                transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.batch_size,
                                              num_workers=4, shuffle=True)
    testset = torchvision.datasets.ImageFolder(args.imagefolder_val,
                                               transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=args.batch_size,
                                             num_workers=4, shuffle=False)
    ####################################################

    ####################################################
    #           Network and Optimizer
    net = get_model(args.model)
    to_log_file(net, args.output, train_log)
    net = net.to(device)
    start_epoch = 0

    if args.optimizer == "SGD":
        optimizer = optim.SGD(net.parameters(), lr=args.lr, weight_decay=2e-4)
    elif args.optimizer == "adam":
        optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=2e-4)
    else:
        # Previously fell through and failed later with a NameError.
        raise ValueError(
            f"Unsupported optimizer '{args.optimizer}'; expected 'SGD' or 'adam'.")
    criterion = nn.CrossEntropyLoss()

    if args.model_path != '':
        print("loading model from path: ", args.model_path)
        state_dict = torch.load(args.model_path, map_location=device)
        net.load_state_dict(state_dict['net'])
        optimizer.load_state_dict(state_dict['optimizer'])
        # NOTE(review): checkpoints written below store the index of the last
        # completed epoch, so resuming from one re-runs that epoch once.
        start_epoch = state_dict['epoch']
    ####################################################

    def save_checkpoint(epoch_idx):
        """Save net/optimizer state tagged with ``epoch_idx`` under args.checkpoint."""
        state = {
            'net': net.state_dict(),
            'epoch': epoch_idx,
            'optimizer': optimizer.state_dict()
        }
        out_str = os.path.join(
            args.checkpoint,
            args.model + '_seed_' + str(args.seed) + '_normalize_' +
            str(args.normalize) + '_augment_' + str(args.train_augment) +
            '_optimizer_' + str(args.optimizer) + '_epoch_' + str(epoch_idx) +
            '.t7')
        print('saving model to: ', args.checkpoint, ' out_str: ', out_str)
        if not os.path.isdir(args.checkpoint):
            os.makedirs(args.checkpoint)
        torch.save(state, out_str)

    ####################################################
    #        Train and Test
    print("==> Training network...")
    loss = 0
    all_losses = []
    all_acc = []
    all_losses_test = []
    all_acc_test = []
    # Epochs at which validation actually ran; used as the x axis of the test
    # curves so they line up even when val_period > 1.
    test_epochs = []
    epoch = start_epoch
    for epoch in tqdm(range(start_epoch, args.epochs)):
        adjust_learning_rate(optimizer, epoch, args.lr_schedule, args.lr_factor)
        loss, acc = train(net, trainloader, optimizer, criterion, device)
        all_losses.append(loss)
        all_acc.append(acc)

        if args.save_net and (epoch + 1) % 10 == 0:
            save_checkpoint(epoch)

        if (epoch + 1) % args.val_period == 0:
            print("Epoch: ", epoch)
            print("Loss: ", loss)
            print("Training acc: ", acc)
            natural_acc, test_loss = test(net, testloader, device, criterion)
            all_losses_test.append(test_loss)
            all_acc_test.append(natural_acc)
            test_epochs.append(epoch)
            print(now(), " Natural accuracy: ", natural_acc, "Test Loss: ",
                  test_loss)
            to_log_file(
                {
                    "epoch": epoch,
                    "loss": loss,
                    "training_acc": acc,
                    "natural_acc": natural_acc
                }, args.output, train_log)

    # The final log entry needs the test accuracy at the last epoch. Evaluate
    # only if the loop did not already validate it (or ran no epochs at all);
    # previously natural_acc could be unbound or stale here.
    if not test_epochs or test_epochs[-1] != epoch:
        natural_acc, _ = test(net, testloader, device, criterion)

    to_log_file({
        "epoch": epoch,
        "loss": loss,
        "natural_acc": natural_acc
    }, args.output, train_log)
    ####################################################

    ####################################################
    #        Save
    if args.save_net:
        save_checkpoint(epoch)
    ####################################################

    os.makedirs('./plots', exist_ok=True)
    total_epochs = np.arange(start_epoch, args.epochs)
    validated_epochs = np.array(test_epochs)
    plot_fig(total_epochs, all_acc, './plots/training_acc.png')
    plot_fig(validated_epochs, all_acc_test, './plots/test_acc.png')
    plot_fig(total_epochs, all_losses, './plots/train_loss.png')
    plot_fig(validated_epochs, all_losses_test, './plots/test_loss.png')

    return
def main(args):
    """Main function to check the success rate of the given poisons

    Loads the poisons, their base indices and the target from pickles in
    ``args.poisons_path``, trains (or fine-tunes, when ``args.ffe``) a network
    on the poisoned training set, and reports whether the target image is
    classified as the poisoned label.

    input:
        args: Argparse object
    return:
        void
    raises:
        ValueError: if the poisons exceed the clean-label perturbation budget,
            or ``args.optimizer`` is neither SGD nor ADAM.
    """
    print(now(), "poison_test.py main() running.")

    test_log = "poison_test_log.txt"
    to_log_file(args, args.output, test_log)

    lr = args.lr

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load the poisons and their indices within the training set from pickled files
    # NOTE(review): pickle.load can execute arbitrary code; poisons_path must be trusted.
    with open(os.path.join(args.poisons_path, "poisons.pickle"), "rb") as handle:
        poison_tuples = pickle.load(handle)
    print(len(poison_tuples), " poisons in this trial.")
    poisoned_label = poison_tuples[0][1]
    with open(os.path.join(args.poisons_path, "base_indices.pickle"), "rb") as handle:
        poison_indices = pickle.load(handle)

    # get the dataset and the dataloaders
    trainloader, testloader, dataset, transform_train, transform_test, num_classes = \
        get_dataset(args, poison_tuples, poison_indices)

    # get the target image from pickled file
    with open(os.path.join(args.poisons_path, "target.pickle"), "rb") as handle:
        target_img_tuple = pickle.load(handle)
    target_class = target_img_tuple[1]
    if len(target_img_tuple) == 4:
        # Patch-attack target: (image, class, patch, (startx, starty)).
        patch = target_img_tuple[2] if torch.is_tensor(target_img_tuple[2]) else \
            torch.tensor(target_img_tuple[2])
        if patch.shape[0] != 3 or patch.shape[1] != args.patch_size or \
                patch.shape[2] != args.patch_size:
            print(
                f"Expected shape of the patch is [3, {args.patch_size}, {args.patch_size}] "
                f"but is {patch.shape}. Exiting from poison_test.py.")
            sys.exit()

        startx, starty = target_img_tuple[3]
        target_img_pil = target_img_tuple[0]
        # PIL's Image.size is (width, height); the old ``h, w = size`` swapped
        # them and checked the bounds against the wrong axes.
        w, h = target_img_pil.size

        if starty + args.patch_size > h or startx + args.patch_size > w:
            print(
                "Invalid startx or starty point for the patch. Exiting from poison_test.py."
            )
            sys.exit()

        # ToTensor yields (C, H, W): rows are indexed by y, columns by x.
        target_img_tensor = transforms.ToTensor()(target_img_pil)
        target_img_tensor[:, starty:starty + args.patch_size,
                          startx:startx + args.patch_size] = patch
        target_img_pil = transforms.ToPILImage()(target_img_tensor)
    else:
        target_img_pil = target_img_tuple[0]

    target_img = transform_test(target_img_pil)

    poison_perturbation_norms = compute_perturbation_norms(
        poison_tuples, dataset, poison_indices)

    # the limit is '8/255' but we check that it is smaller than 9/255 to account for PIL
    # truncation. (Explicit check so it is not stripped under ``python -O``.)
    if not max(poison_perturbation_norms) - 9 / 255 < 1e-5:
        raise ValueError("Attack not clean label!")
    ####################################################

    ####################################################
    #           Network and Optimizer

    # load model from path if a path is provided
    if args.model_path is not None:
        net = load_model_from_checkpoint(args.model, args.model_path,
                                         args.pretrain_dataset)
    else:
        args.ffe = False  # we wouldn't fine tune from a random initialization
        net = get_model(args.model, args.dataset)

    # freeze weights in feature extractor if not doing from scratch retraining
    if args.ffe:
        for param in net.parameters():
            param.requires_grad = False

    # reinitialize the linear layer
    num_ftrs = net.linear.in_features
    net.linear = nn.Linear(num_ftrs, num_classes)  # requires grad by default

    # The new linear layer is created on the default device; move the whole
    # network to ``device``, which is where the target image is sent below.
    net = net.to(device)

    # set optimizer
    if args.optimizer.upper() == "SGD":
        optimizer = optim.SGD(net.parameters(),
                              lr=lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9)
    elif args.optimizer.upper() == "ADAM":
        optimizer = optim.Adam(net.parameters(),
                               lr=lr,
                               weight_decay=args.weight_decay)
    else:
        raise ValueError(
            f"Unsupported optimizer '{args.optimizer}'; expected SGD or ADAM.")
    criterion = nn.CrossEntropyLoss()
    ####################################################

    def predict_target():
        """Return the class index the network predicts for the target image."""
        net.eval()
        with torch.no_grad():
            return net(target_img.unsqueeze(0).to(device)).max(1)[1].item()

    ####################################################
    #        Poison and Train and Test
    print("==> Training network...")
    epoch = 0
    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_schedule, args.lr_factor)
        loss, acc = train(net,
                          trainloader,
                          optimizer,
                          criterion,
                          device,
                          train_bn=not args.ffe)

        if (epoch + 1) % args.val_period == 0:
            natural_acc = test(net, testloader, device)
            p_acc = predict_target() == poisoned_label
            print(
                now(),
                " Epoch: ",
                epoch,
                ", Loss: ",
                loss,
                ", Training acc: ",
                acc,
                ", Natural accuracy: ",
                natural_acc,
                ", poison success: ",
                p_acc,
            )

    # test
    natural_acc = test(net, testloader, device)
    print(now(), " Training ended at epoch ", epoch, ", Natural accuracy: ",
          natural_acc)
    prediction = predict_target()
    p_acc = prediction == poisoned_label
    print(
        now(),
        " poison success: ",
        p_acc,
        " poisoned_label: ",
        poisoned_label,
        " prediction: ",
        prediction,
    )

    # Dictionary to write contest the csv file
    stats = OrderedDict([
        ("poisons path", args.poisons_path),
        ("model",
         args.model_path if args.model_path is not None else args.model),
        ("target class", target_class),
        ("base class", poisoned_label),
        ("num poisons", len(poison_tuples)),
        ("max perturbation norm", np.max(poison_perturbation_norms)),
        ("epoch", epoch),
        ("loss", loss),
        ("training_acc", acc),
        ("natural_acc", natural_acc),
        ("poison_acc", p_acc),
    ])
    to_results_table(stats, args.output)
    ####################################################

    return
def main(args):
    """Main function to craft polytope poisons via ``make_convex_polytope_poisons``

    Picks the target (from the test set) and bases (from the training set)
    given by ``args.poison_setups``/``args.setup_idx`` or the explicit
    overrides, crafts poisons against the substitute networks in
    ``args.model_path`` with ``mode="mean"``, and pickles the poisons, their
    perturbation norms, the base indices and the target into
    ``args.poisons_path``.

    inputs:
        args: Argparse object
    return:
        void
    """
    print(now(), "craft_poisons_bp.py main() running.")

    craft_log = "craft_log.txt"
    to_log_file(args, args.output, craft_log)

    # Crafting device. This was hard-coded to "cuda", which is still what is
    # used whenever CUDA is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ####################################################
    #               Dataset
    dataset_key = args.dataset.lower()
    tiny_class_subsets = {
        "tinyimagenet_first": "firsthalf",
        "tinyimagenet_last": "lasthalf",
        "tinyimagenet_all": "all",
    }
    if dataset_key == "cifar10":
        transform_test = get_transform(args.normalize, False)
        testset = torchvision.datasets.CIFAR10(
            root="./data", train=False, download=True, transform=transform_test
        )
        trainset = torchvision.datasets.CIFAR10(
            root="./data", train=True, download=True, transform=transform_test
        )
    elif dataset_key in tiny_class_subsets:
        transform_test = get_transform(args.normalize, False, dataset=args.dataset)
        trainset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="train",
            transform=transform_test,
            classes=tiny_class_subsets[dataset_key],
        )
        testset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="val",
            transform=transform_test,
            classes=tiny_class_subsets[dataset_key],
        )
    else:
        print("Dataset not yet implemented. Exiting from craft_poisons_bp.py.")
        sys.exit()
    ###################################################

    ####################################################
    #         Find target and base images
    # NOTE(review): pickle.load can execute arbitrary code; poison_setups must be trusted.
    with open(args.poison_setups, "rb") as handle:
        setup_dicts = pickle.load(handle)
    setup = setup_dicts[args.setup_idx]
    target_img_idx = (
        setup["target index"] if args.target_img_idx is None else args.target_img_idx
    )
    base_indices = (
        setup["base indices"] if args.base_indices is None else args.base_indices
    )

    # get single target
    target_img, target_label = testset[target_img_idx]

    # get multiple bases
    base_imgs = torch.stack([trainset[i][0] for i in base_indices])
    base_labels = [trainset[i][1] for i in base_indices]
    poisoned_label = base_labels[0]

    # log target and base details
    to_log_file("base indices: " + str(base_indices), args.output, craft_log)
    to_log_file("base labels: " + str(base_labels), args.output, craft_log)
    to_log_file("target_label: " + str(target_label), args.output, craft_log)
    to_log_file("target_index: " + str(target_img_idx), args.output, craft_log)

    # Let cuDNN benchmark and pick convolution algorithms.
    cudnn.benchmark = True

    # load the pre-trained models
    sub_net_list = []
    for n_model, chk_name in enumerate(args.model_path):
        net = load_model_from_checkpoint(
            args.model[n_model], chk_name, args.pretrain_dataset
        )
        sub_net_list.append(net)

    target_net = load_model_from_checkpoint(
        args.target_model, args.target_model_path, args.pretrain_dataset
    )

    # Get the target image
    target = target_img.unsqueeze(0)

    chk_path = args.poisons_path
    if not os.path.exists(chk_path):
        os.makedirs(chk_path)

    # One tensor per base image, on the crafting device; also used as the
    # poison initialization.
    base_tensor_list = [base_img.to(device) for base_img in base_imgs]
    poison_init = base_tensor_list
    mean, std = data_mean_std_dict[args.dataset.lower()]
    poison_tuple_list = make_convex_polytope_poisons(
        sub_net_list,
        target_net,
        base_tensor_list,
        target,
        device,
        opt_method=args.poison_opt,
        lr=args.poison_lr,
        momentum=args.poison_momentum,
        iterations=args.crafting_iters,
        epsilon=args.epsilon,
        decay_ites=args.poison_decay_ites,
        decay_ratio=args.poison_decay_ratio,
        mean=torch.Tensor(mean).reshape(1, 3, 1, 1),
        std=torch.Tensor(std).reshape(1, 3, 1, 1),
        chk_path=chk_path,
        poison_idxes=base_indices,
        poison_label=poisoned_label,
        tol=args.tol,
        start_ite=0,
        poison_init=poison_init,
        end2end=args.end2end,
        mode="mean",
    )

    # move poisons to PIL format
    to_pil = transforms.ToPILImage()
    if args.normalize:
        target = un_normalize_data(target.squeeze(0), args.dataset)
        poison_tuple_list = [
            (to_pil(un_normalize_data(poison[0], args.dataset)), poison[1])
            for poison in poison_tuple_list
        ]
    else:
        target = target.squeeze(0)
        poison_tuple_list = [
            (to_pil(poison[0]), poison[1]) for poison in poison_tuple_list
        ]

    # get perturbation norms (L-inf distance in [0, 1] pixel space)
    poison_perturbation_norms = []
    for idx, (poison_pil, _) in enumerate(poison_tuple_list):
        base_pixels = base_tensor_list[idx].cpu()
        # Bases are only in normalized space when args.normalize is set; the
        # old code un-normalized them unconditionally, corrupting the norms
        # of un-normalized runs.
        if args.normalize:
            base_pixels = un_normalize_data(base_pixels, args.dataset)
        poison_perturbation_norms.append(
            torch.max(
                torch.abs(transforms.ToTensor()(poison_pil) - base_pixels)
            ).item()
        )
    to_log_file(
        "perturbation norms: " + str(poison_perturbation_norms),
        args.output,
        craft_log,
    )

    ####################################################
    #        Save Poisons
    print(now(), "Saving poisons...")
    if not os.path.isdir(args.poisons_path):
        os.makedirs(args.poisons_path)
    with open(os.path.join(args.poisons_path, "poisons.pickle"), "wb") as handle:
        pickle.dump(poison_tuple_list, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(
        os.path.join(args.poisons_path, "perturbation_norms.pickle"), "wb"
    ) as handle:
        pickle.dump(poison_perturbation_norms, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.poisons_path, "base_indices.pickle"), "wb") as handle:
        pickle.dump(base_indices, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.poisons_path, "target.pickle"), "wb") as handle:
        pickle.dump(
            (transforms.ToPILImage()(target), target_label),
            handle,
            protocol=pickle.HIGHEST_PROTOCOL,
        )
    to_log_file("poisons saved.", args.output, craft_log)
    ####################################################

    print(now(), "craft_poisons_bp.py done.")
    return
def main(args):
    """Main function to generate the FC (feature collision) poisons

    Each base image is first watermarked with the target image. It is then
    optimized so that its penultimate-layer features (``penultimate``) under
    every model in ``args.model`` approach the target's features. The result
    stays within an L-inf ``args.epsilon`` ball of the base, or uses an L2
    proximal penalty when ``args.l2``. The poisons, their perturbation norms,
    the base indices and the target are pickled into ``args.poisons_path``.

    inputs:
        args: Argparse object
    return:
        void
    """
    print(now(), "craft_poisons_fc.py main() running.")

    craft_log = "craft_log.txt"
    to_log_file(args, args.output, craft_log)

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ####################################################
    #               Dataset
    dataset_key = args.dataset.lower()
    tiny_class_subsets = {
        "tinyimagenet_first": "firsthalf",
        "tinyimagenet_last": "lasthalf",
        "tinyimagenet_all": "all",
    }
    if dataset_key == "cifar10":
        transform_test = get_transform(args.normalize, False)
        trainset = torchvision.datasets.CIFAR10(root="./data",
                                                train=True,
                                                download=True,
                                                transform=transform_test)
        testset = torchvision.datasets.CIFAR10(root="./data",
                                               train=False,
                                               download=True,
                                               transform=transform_test)
    elif dataset_key in tiny_class_subsets:
        transform_test = get_transform(args.normalize, False, dataset=args.dataset)
        trainset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="train",
            transform=transform_test,
            classes=tiny_class_subsets[dataset_key],
        )
        testset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="val",
            transform=transform_test,
            classes=tiny_class_subsets[dataset_key],
        )
    else:
        print("Dataset not yet implemented. Exiting from craft_poisons_fc.py.")
        sys.exit()
    ###################################################

    ####################################################
    #         Craft and insert poison image
    # Frozen, eval-mode feature extractors (one per model / checkpoint pair).
    feature_extractors = []
    for model_idx, model_name in enumerate(args.model):
        extractor = load_model_from_checkpoint(model_name,
                                               args.model_path[model_idx],
                                               args.pretrain_dataset)
        for param in extractor.parameters():
            param.requires_grad = False
        extractor.eval()
        feature_extractors.append(extractor.to(device))

    # NOTE(review): pickle.load can execute arbitrary code; poison_setups must be trusted.
    with open(args.poison_setups, "rb") as handle:
        setup_dicts = pickle.load(handle)
    setup = setup_dicts[args.setup_idx]
    target_img_idx = (setup["target index"]
                      if args.target_img_idx is None else args.target_img_idx)
    base_indices = (setup["base indices"]
                    if args.base_indices is None else args.base_indices)

    # Craft poisons
    poison_iterations = args.crafting_iters
    poison_perturbation_norms = []

    # get single target
    target_img, target_label = testset[target_img_idx]

    # get multiple bases
    base_imgs = torch.stack([trainset[i][0] for i in base_indices])
    base_labels = torch.LongTensor([trainset[i][1] for i in base_indices])

    # log target and base details
    to_log_file("base indices: " + str(base_indices), args.output, craft_log)
    to_log_file("base labels: " + str(base_labels), args.output, craft_log)
    to_log_file("target_label: " + str(target_label), args.output, craft_log)
    to_log_file("target_index: " + str(target_img_idx), args.output, craft_log)

    # fill list of tuples of poison images and labels
    poison_tuples = []
    target_img = (un_normalize_data(target_img, args.dataset)
                  if args.normalize else target_img)
    beta = 4.0 if args.normalize else 0.1

    for base_img, label in zip(base_imgs, base_labels):
        # unnormalize the images for optimization
        b_unnormalized = (un_normalize_data(base_img, args.dataset)
                          if args.normalize else base_img)
        # Accepted objective values (plain floats); starts with a large sentinel.
        objective_vals = [10e8]
        step_size = args.step_size

        # watermarking
        x = args.watermark_coeff * target_img + (1 - args.watermark_coeff) * b_unnormalized

        # feature collision optimization
        done_with_fc = False
        i = 0
        while not done_with_fc and i < poison_iterations:
            x.requires_grad = True
            if args.normalize:
                mini_batch = torch.stack([
                    normalize_data(x, args.dataset),
                    normalize_data(target_img, args.dataset),
                ]).to(device)
            else:
                mini_batch = torch.stack([x, target_img]).to(device)

            loss = 0
            for feature_extractor in feature_extractors:
                feats = feature_extractor.penultimate(mini_batch)
                loss += torch.norm(feats[0, :] - feats[1, :])**2
            grad = torch.autograd.grad(loss, [x])[0]
            x_hat = x.detach() - step_size * grad.detach()
            # ``obj`` is stored as a Python float: keeping the loss tensor
            # would keep every iteration's autograd graph alive through
            # objective_vals, growing memory with the iteration count.
            if not args.l2:
                pert = (x_hat - b_unnormalized).clamp(-args.epsilon, args.epsilon)
                x_new = (b_unnormalized + pert).clamp(0, 1)
                obj = loss.item()
            else:
                # Proximal step for the squared-L2 penalty beta * ||x - b||^2.
                x_new = (x_hat.detach() + step_size * beta * b_unnormalized.detach()) / (1 + step_size * beta)
                x_new = x_new.clamp(0, 1)
                obj = (beta * torch.norm(x_new - b_unnormalized)**2 + loss).item()

            if obj > objective_vals[-1]:
                # Objective got worse: reject the step and shrink the step size.
                step_size *= 0.2
            else:
                if torch.norm(x - x_new) / torch.norm(x) < 1e-5:
                    done_with_fc = True
                x = copy.deepcopy(x_new)
                objective_vals.append(obj)
            i += 1

        # x still requires grad if the final step was rejected; detach before
        # converting to PIL and measuring the perturbation.
        x = x.detach()
        poison_tuples.append((transforms.ToPILImage()(x), label.item()))
        poison_perturbation_norms.append(
            torch.max(torch.abs(x - b_unnormalized)).item())
    ####################################################

    ####################################################
    #        Save Poisons
    print(now(), "Saving poisons...")
    if not os.path.isdir(args.poisons_path):
        os.makedirs(args.poisons_path)
    with open(os.path.join(args.poisons_path, "poisons.pickle"), "wb") as handle:
        pickle.dump(poison_tuples, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.poisons_path, "perturbation_norms.pickle"),
              "wb") as handle:
        pickle.dump(poison_perturbation_norms, handle,
                    protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.poisons_path, "base_indices.pickle"), "wb") as handle:
        pickle.dump(base_indices, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.poisons_path, "target.pickle"), "wb") as handle:
        pickle.dump(
            (transforms.ToPILImage()(target_img), target_label),
            handle,
            protocol=pickle.HIGHEST_PROTOCOL,
        )
    to_log_file("poisons saved.", args.output, craft_log)
    ####################################################

    print(now(), "craft_poisons_fc.py done.")
    return
def main(args):
    """Main function to train and test a model

    Builds the requested dataset, with the training set wrapped in
    ``PoisonedDataset`` using an empty poison tuple ``()``. It then trains for
    ``args.epochs`` epochs (optionally resuming from ``args.model_path``),
    validates every ``args.val_period`` epochs, and optionally saves the final
    checkpoint under ``args.checkpoint``.

    input:
        args: Argparse object that contains all the parsed values
    return:
        void
    raises:
        ValueError: if ``args.optimizer`` is neither ``"SGD"`` nor ``"adam"``.
    """
    print(args)
    print(now(), "train_model.py main() running.")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    train_log = "train_log.txt"
    to_log_file(args, args.output, train_log)

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ####################################################
    #           Load the Dataset
    dataset_key = args.dataset.lower()
    tiny_class_subsets = {
        "tinyimagenet_first": "firsthalf",
        "tinyimagenet_last": "lasthalf",
        "tinyimagenet_all": "all",
    }
    if dataset_key in ("cifar10", "cifar100"):
        cifar_cls = (torchvision.datasets.CIFAR10 if dataset_key == "cifar10"
                     else torchvision.datasets.CIFAR100)
        transform_train = get_transform(args.normalize, args.train_augment)
        transform_test = get_transform(args.normalize, False)
        trainset = cifar_cls(root="./data",
                             train=True,
                             download=True,
                             transform=transform_train)
        trainset = PoisonedDataset(trainset, (),
                                   args.trainset_size,
                                   transform=transform_train)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=args.batch_size,
                                                  shuffle=True)
        testset = cifar_cls(root="./data",
                            train=False,
                            download=True,
                            transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=128,
                                                 shuffle=False)
    elif dataset_key in tiny_class_subsets:
        class_subset = tiny_class_subsets[dataset_key]
        transform_train = get_transform(args.normalize,
                                        args.train_augment,
                                        dataset=args.dataset)
        transform_test = get_transform(args.normalize,
                                       False,
                                       dataset=args.dataset)
        trainset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="train",
            transform=transform_train,
            classes=class_subset,
        )
        trainset = PoisonedDataset(trainset, (),
                                   args.trainset_size,
                                   transform=transform_train)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=64,
                                                  num_workers=1,
                                                  shuffle=True)
        testset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="val",
            transform=transform_test,
            classes=class_subset,
        )
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=64,
                                                 num_workers=1,
                                                 shuffle=False)
    else:
        print("Dataset not yet implemented. Ending run from train_model.py.")
        sys.exit()
    ####################################################

    ####################################################
    #           Network and Optimizer
    net = get_model(args.model, args.dataset)
    net = net.to(device)
    start_epoch = 0

    if args.optimizer == "SGD":
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              weight_decay=2e-4,
                              momentum=0.9)
    elif args.optimizer == "adam":
        optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=2e-4)
    else:
        # Previously fell through and failed later with a NameError.
        raise ValueError(
            f"Unsupported optimizer '{args.optimizer}'; expected 'SGD' or 'adam'.")
    criterion = nn.CrossEntropyLoss()

    if args.model_path is not None:
        state_dict = torch.load(args.model_path, map_location=device)
        net.load_state_dict(state_dict["net"])
        optimizer.load_state_dict(state_dict["optimizer"])
        # NOTE(review): checkpoints saved below store the index of the last
        # completed epoch, so resuming from one re-runs that epoch once.
        start_epoch = state_dict["epoch"]
    ####################################################

    ####################################################
    #        Train and Test
    print("==> Training network...")
    loss = 0
    all_losses = []
    epoch = start_epoch

    print(f"Training for {args.epochs - start_epoch} epochs.")
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_schedule,
                             args.lr_factor)
        loss, acc = train(net, trainloader, optimizer, criterion, device)
        all_losses.append(loss)

        if (epoch + 1) % args.val_period == 0:
            natural_acc = test(net, testloader, device)
            print(
                now(),
                " Epoch: ",
                epoch,
                ", Loss: ",
                loss,
                ", Training acc: ",
                acc,
                ", Natural accuracy: ",
                natural_acc,
            )
            to_log_file(
                {
                    "epoch": epoch,
                    "loss": loss,
                    "training_acc": acc,
                    "natural_acc": natural_acc,
                },
                args.output,
                train_log,
            )

    # test
    natural_acc = test(net, testloader, device)
    print(now(), " Training ended at epoch ", epoch, ", Natural accuracy: ",
          natural_acc)
    to_log_file(
        {
            "epoch": epoch,
            "loss": loss,
            "natural_acc": natural_acc
        },
        args.output,
        train_log,
    )
    ####################################################

    ####################################################
    #        Save
    if args.save_net:
        state = {
            "net": net.state_dict(),
            "epoch": epoch,
            "optimizer": optimizer.state_dict(),
        }
        out_str = os.path.join(
            args.checkpoint,
            args.model + "_seed_" + str(args.seed) + "_normalize=" +
            str(args.normalize) + "_augment=" + str(args.train_augment) +
            "_optimizer=" + str(args.optimizer) + "_epoch=" + str(epoch) +
            ".pth",
        )
        print("Saving model to: ", args.checkpoint, " out_str: ", out_str)
        if not os.path.isdir(args.checkpoint):
            os.makedirs(args.checkpoint)
        torch.save(state, out_str)
    ####################################################

    return