def main(args):
    """Main function to generate the CLBD (clean-label backdoor) poisons.

    inputs:
        args: Argparse object
    return:
        void
    """
    print(now(), "craft_poisons_clbd.py main() running...")

    mean, std = data_mean_std_dict[args.dataset.lower()]
    mean = list(mean)
    std = list(std)
    normalize_net = NormalizeByChannelMeanStd(mean, std)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model_from_checkpoint(args.model[0], args.model_path[0],
                                       args.pretrain_dataset)
    model.eval()
    if args.normalize:
        # Fold input normalization into the model so the attack itself can
        # operate on images in plain [0, 1] pixel space.
        model = nn.Sequential(normalize_net, model)
    model = model.to(device)

    ####################################################
    #               Dataset
    # All tinyimagenet_* variants differ only in the `classes` argument.
    tinyimagenet_classes = {
        "tinyimagenet_first": "firsthalf",
        "tinyimagenet_last": "lasthalf",
        "tinyimagenet_all": "all",
    }
    if args.dataset.lower() == "cifar10":
        transform_test = get_transform(False, False)
        testset = torchvision.datasets.CIFAR10(root="./data", train=False,
                                               download=True,
                                               transform=transform_test)
        trainset = torchvision.datasets.CIFAR10(root="./data", train=True,
                                                download=True,
                                                transform=transform_test)
    elif args.dataset.lower() in tinyimagenet_classes:
        transform_test = get_transform(False, False, dataset=args.dataset)
        classes = tinyimagenet_classes[args.dataset.lower()]
        trainset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="train",
            transform=transform_test,
            classes=classes,
        )
        testset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="val",
            transform=transform_test,
            classes=classes,
        )
    else:
        print("Dataset not yet implemented. Exiting from craft_poisons_clbd.py.")
        sys.exit()
    ###################################################

    with open(args.poison_setups, "rb") as handle:
        setup_dicts = pickle.load(handle)
    setup = setup_dicts[args.setup_idx]

    # Command-line overrides take precedence over the pickled setup.
    target_img_idx = (setup["target index"] if args.target_img_idx is None
                      else args.target_img_idx)
    base_indices = (setup["base indices"] if args.base_indices is None
                    else args.base_indices)

    # get single target
    target_img, target_label = testset[target_img_idx]

    # get multiple bases
    base_imgs = torch.stack([trainset[i][0] for i in base_indices]).to(device)
    base_labels = torch.LongTensor([trainset[i][1]
                                    for i in base_indices]).to(device)

    # get attacker
    config = {
        "epsilon": args.epsilon,
        "step_size": args.step_size,
        "num_steps": args.num_steps,
    }
    attacker = AttackPGD(model, config)

    # get patch (the backdoor trigger image, resized to the patch size)
    trans_trigger = transforms.Compose([
        transforms.Resize((args.patch_size, args.patch_size)),
        transforms.ToTensor(),
    ])
    trigger = Image.open("./poison_crafting/triggers/clbd.png").convert("RGB")
    trigger = trans_trigger(trigger).unsqueeze(0).to(device)

    # craft poisons: split the bases into chunks of at most 1000 images to
    # bound per-attack memory use
    num_batches = int(np.ceil(base_imgs.shape[0] / 1000))
    batches = [(base_imgs[1000 * i:1000 * (i + 1)],
                base_labels[1000 * i:1000 * (i + 1)])
               for i in range(num_batches)]

    # attack all the bases
    adv_batches = []
    for batch_img, batch_labels in batches:
        adv_batches.append(attacker(batch_img, batch_labels))
    adv_bases = torch.cat(adv_batches)

    # Starting coordinates of the patch (bottom-right corner of the image)
    start_x = args.image_size - args.patch_size
    start_y = args.image_size - args.patch_size

    # Mask out the patch region so the adversarial perturbation is removed
    # where the trigger will be pasted.
    # FIX: adv_bases is 4-D (N, C, H, W); the original indexed
    # mask[:, start_y:..., start_x:...], slicing the channel and height
    # dimensions, which made the mask assignment a no-op. Index H and W.
    mask = torch.ones_like(adv_bases)
    # uncomment below for patching all corners
    mask[:, :, start_y:start_y + args.patch_size,
         start_x:start_x + args.patch_size] = 0
    # mask[:, :, 0:args.patch_size, start_x:start_x + args.patch_size] = 0
    # mask[:, :, start_y:start_y + args.patch_size, 0:args.patch_size] = 0
    # mask[:, :, 0:args.patch_size, 0:args.patch_size] = 0
    pert = (adv_bases - base_imgs) * mask
    adv_bases_masked = base_imgs + pert

    # Attach the patch to the masked adversarial bases
    for i in range(len(base_imgs)):
        # uncomment below for patching all corners
        adv_bases_masked[i, :, start_y:start_y + args.patch_size,
                         start_x:start_x + args.patch_size] = trigger
        # adv_bases_masked[i, :, 0:args.patch_size,
        #                  start_x:start_x + args.patch_size] = trigger
        # adv_bases_masked[i, :, start_y:start_y + args.patch_size,
        #                  0:args.patch_size] = torch.flip(trigger, (-1,))
        # adv_bases_masked[i, :, 0:args.patch_size,
        #                  0:args.patch_size] = torch.flip(trigger, (-1,))

    # Project the total perturbation (attack + patch) into the epsilon ball
    # and the valid pixel range so the poisons stay clean-label.
    final_pert = torch.clamp(adv_bases_masked - base_imgs,
                             -args.epsilon, args.epsilon)
    poisons = base_imgs + final_pert
    poisons = poisons.clamp(0, 1).cpu()
    poisoned_tuples = [(transforms.ToPILImage()(poisons[i]),
                        base_labels[i].item())
                       for i in range(poisons.shape[0])]
    # target tuple: (image, label, trigger patch, patch location)
    target_tuple = (
        transforms.ToPILImage()(target_img),
        int(target_label),
        trigger.squeeze(0).cpu(),
        [start_x, start_y],
    )

    ####################################################
    #        Save Poisons
    print(now(), "Saving poisons...")
    if not os.path.isdir(args.poisons_path):
        os.makedirs(args.poisons_path)
    with open(os.path.join(args.poisons_path, "poisons.pickle"), "wb") as handle:
        pickle.dump(poisoned_tuples, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.poisons_path, "target.pickle"), "wb") as handle:
        pickle.dump(
            target_tuple,
            handle,
            protocol=pickle.HIGHEST_PROTOCOL,
        )
    with open(os.path.join(args.poisons_path, "base_indices.pickle"),
              "wb") as handle:
        pickle.dump(base_indices, handle, protocol=pickle.HIGHEST_PROTOCOL)
    ####################################################
    print(now(), "craft_poisons_clbd.py done.")
    return
def main(args):
    """Main function to test a model.

    input:
        args: Argparse object that contains all the parsed values
    return:
        void
    """
    print(now(), "test_model.py main() running.")
    test_log = "clean_test_log.txt"
    to_log_file(args, args.output, test_log)

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ####################################################
    #               Dataset
    # The four original branches were near-duplicates; the CIFAR variants
    # differ only in the dataset class, the TinyImageNet variants only in
    # the `classes` split argument.
    dataset = args.dataset.lower()
    cifar_datasets = {
        "cifar10": torchvision.datasets.CIFAR10,
        "cifar100": torchvision.datasets.CIFAR100,
    }
    tinyimagenet_classes = {
        "tinyimagenet_first": "firsthalf",
        "tinyimagenet_last": "lasthalf",
        "tinyimagenet_all": "all",
    }
    if dataset in cifar_datasets:
        transform_train = get_transform(args.normalize, args.train_augment)
        transform_test = get_transform(args.normalize, False)
        dataset_cls = cifar_datasets[dataset]
        trainset = dataset_cls(root="./data", train=True, download=True,
                               transform=transform_train)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=128)
        testset = dataset_cls(root="./data", train=False, download=True,
                              transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                                 shuffle=False)
    elif dataset in tinyimagenet_classes:
        transform_train = get_transform(args.normalize, args.train_augment,
                                        dataset=args.dataset)
        transform_test = get_transform(args.normalize, False,
                                       dataset=args.dataset)
        classes = tinyimagenet_classes[dataset]
        trainset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="train",
            transform=transform_train,
            classes=classes,
        )
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                                  num_workers=1, shuffle=True)
        testset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="val",
            transform=transform_test,
            classes=classes,
        )
        testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                                 num_workers=1, shuffle=False)
    else:
        print("Dataset not yet implemented. Exiting from test_model.py.")
        sys.exit()
    ####################################################

    ####################################################
    #           Network and Optimizer
    net = get_model(args.model, args.dataset)

    # load model from path if a path is provided
    if args.model_path is not None:
        net = load_model_from_checkpoint(args.model, args.model_path,
                                         args.dataset)
    else:
        print("No model path provided, continuing test with untrained network.")
    net = net.to(device)
    ####################################################

    ####################################################
    #        Test Model
    training_acc = test(net, trainloader, device)
    natural_acc = test(net, testloader, device)
    print(now(), " Training accuracy: ", training_acc)
    print(now(), " Natural accuracy: ", natural_acc)
    stats = OrderedDict([
        ("model path", args.model_path),
        ("model", args.model),
        ("normalize", args.normalize),
        ("augment", args.train_augment),
        ("training_acc", training_acc),
        ("natural_acc", natural_acc),
    ])
    to_results_table(stats, args.output, "clean_performance.csv")
    ####################################################
    return
def main(args):
    """Main function to generate the CP poisons.

    NOTE(review): messages in this function mix "craft_poisons_bp.py" and
    "craft_poisons_cp.py"; make_convex_polytope_poisons is called with
    mode="mean". Confirm which attack (convex polytope vs. bullseye polytope)
    this script is meant to implement and unify the log strings.

    inputs:
        args: Argparse object
    return:
        void
    """
    print(now(), "craft_poisons_bp.py main() running.")
    craft_log = "craft_log.txt"
    to_log_file(args, args.output, craft_log)

    ####################################################
    #               Dataset
    # All tinyimagenet_* variants differ only in the `classes` argument.
    tinyimagenet_classes = {
        "tinyimagenet_first": "firsthalf",
        "tinyimagenet_last": "lasthalf",
        "tinyimagenet_all": "all",
    }
    if args.dataset.lower() == "cifar10":
        transform_test = get_transform(args.normalize, False)
        testset = torchvision.datasets.CIFAR10(
            root="./data", train=False, download=True, transform=transform_test
        )
        trainset = torchvision.datasets.CIFAR10(
            root="./data", train=True, download=True, transform=transform_test
        )
    elif args.dataset.lower() in tinyimagenet_classes:
        transform_test = get_transform(args.normalize, False,
                                       dataset=args.dataset)
        classes = tinyimagenet_classes[args.dataset.lower()]
        trainset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="train",
            transform=transform_test,
            classes=classes,
        )
        testset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="val",
            transform=transform_test,
            classes=classes,
        )
    else:
        print("Dataset not yet implemented. Exiting from craft_poisons_cp.py.")
        sys.exit()
    ###################################################

    ####################################################
    #       Find target and base images
    with open(args.poison_setups, "rb") as handle:
        setup_dicts = pickle.load(handle)
    setup = setup_dicts[args.setup_idx]

    # Command-line overrides take precedence over the pickled setup.
    target_img_idx = (
        setup["target index"] if args.target_img_idx is None
        else args.target_img_idx
    )
    base_indices = (
        setup["base indices"] if args.base_indices is None
        else args.base_indices
    )

    # get single target
    target_img, target_label = testset[target_img_idx]

    # get multiple bases
    base_imgs = torch.stack([trainset[i][0] for i in base_indices])
    base_labels = [trainset[i][1] for i in base_indices]
    poisoned_label = base_labels[0]

    # log target and base details
    to_log_file("base indices: " + str(base_indices), args.output, craft_log)
    to_log_file("base labels: " + str(base_labels), args.output, craft_log)
    to_log_file("target_label: " + str(target_label), args.output, craft_log)
    to_log_file("target_index: " + str(target_img_idx), args.output, craft_log)

    cudnn.benchmark = True

    # FIX: the original hardcoded "cuda" and crashed on CPU-only machines;
    # detect device availability like the other entry points do.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load the pre-trained models
    sub_net_list = []
    for n_model, chk_name in enumerate(args.model_path):
        net = load_model_from_checkpoint(
            args.model[n_model], chk_name, args.pretrain_dataset
        )
        sub_net_list.append(net)
    target_net = load_model_from_checkpoint(
        args.target_model, args.target_model_path, args.pretrain_dataset
    )

    # Get the target image
    target = target_img.unsqueeze(0)

    chk_path = args.poisons_path
    if not os.path.exists(chk_path):
        os.makedirs(chk_path)

    base_tensor_list = [base_imgs[i].to(device)
                        for i in range(base_imgs.shape[0])]
    poison_init = base_tensor_list

    mean, std = data_mean_std_dict[args.dataset.lower()]
    poison_tuple_list = make_convex_polytope_poisons(
        sub_net_list,
        target_net,
        base_tensor_list,
        target,
        device,
        opt_method=args.poison_opt,
        lr=args.poison_lr,
        momentum=args.poison_momentum,
        iterations=args.crafting_iters,
        epsilon=args.epsilon,
        decay_ites=args.poison_decay_ites,
        decay_ratio=args.poison_decay_ratio,
        mean=torch.Tensor(mean).reshape(1, 3, 1, 1),
        std=torch.Tensor(std).reshape(1, 3, 1, 1),
        chk_path=chk_path,
        poison_idxes=base_indices,
        poison_label=poisoned_label,
        tol=args.tol,
        start_ite=0,
        poison_init=poison_init,
        end2end=args.end2end,
        mode="mean",
    )

    # move poisons to PIL format, undoing dataset normalization when it
    # was applied
    if args.normalize:
        target = un_normalize_data(target.squeeze(0), args.dataset)
        for i in range(len(poison_tuple_list)):
            poison_tuple_list[i] = (
                transforms.ToPILImage()(
                    un_normalize_data(poison_tuple_list[i][0], args.dataset)
                ),
                poison_tuple_list[i][1],
            )
    else:
        target = target.squeeze(0)
        for i in range(len(poison_tuple_list)):
            poison_tuple_list[i] = (
                transforms.ToPILImage()(poison_tuple_list[i][0]),
                poison_tuple_list[i][1],
            )

    # get perturbation norms (max absolute deviation from each base image)
    # NOTE(review): the bases are un-normalized here even when args.normalize
    # is False — confirm un_normalize_data is a no-op for unnormalized data.
    poison_perturbation_norms = []
    for idx, (poison_tensor, p_label) in enumerate(poison_tuple_list):
        poison_perturbation_norms.append(
            torch.max(
                torch.abs(
                    transforms.ToTensor()(poison_tensor)
                    - un_normalize_data(base_tensor_list[idx].cpu(),
                                        args.dataset)
                )
            ).item()
        )
    to_log_file(
        "perturbation norms: " + str(poison_perturbation_norms),
        args.output,
        craft_log,
    )

    ####################################################
    #        Save Poisons
    print(now(), "Saving poisons...")
    if not os.path.isdir(args.poisons_path):
        os.makedirs(args.poisons_path)
    with open(os.path.join(args.poisons_path, "poisons.pickle"),
              "wb") as handle:
        pickle.dump(poison_tuple_list, handle,
                    protocol=pickle.HIGHEST_PROTOCOL)
    with open(
        os.path.join(args.poisons_path, "perturbation_norms.pickle"), "wb"
    ) as handle:
        pickle.dump(poison_perturbation_norms, handle,
                    protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.poisons_path, "base_indices.pickle"),
              "wb") as handle:
        pickle.dump(base_indices, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.poisons_path, "target.pickle"),
              "wb") as handle:
        pickle.dump(
            (transforms.ToPILImage()(target), target_label),
            handle,
            protocol=pickle.HIGHEST_PROTOCOL,
        )
    to_log_file("poisons saved.", args.output, craft_log)
    ####################################################
    print(now(), "craft_poisons_bp.py done.")
    return
def main(args):
    """Main function to check the success rate of the given poisons.

    input:
        args: Argparse object
    return:
        void
    """
    print(now(), "poison_test.py main() running.")
    test_log = "poison_test_log.txt"
    to_log_file(args, args.output, test_log)
    lr = args.lr

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load the poisons and their indices within the training set from
    # pickled files
    with open(os.path.join(args.poisons_path, "poisons.pickle"),
              "rb") as handle:
        poison_tuples = pickle.load(handle)
    print(len(poison_tuples), " poisons in this trial.")
    poisoned_label = poison_tuples[0][1]
    with open(os.path.join(args.poisons_path, "base_indices.pickle"),
              "rb") as handle:
        poison_indices = pickle.load(handle)

    # get the dataset and the dataloaders
    trainloader, testloader, dataset, transform_train, transform_test, \
        num_classes = get_dataset(args, poison_tuples, poison_indices)

    # get the target image from pickled file
    with open(os.path.join(args.poisons_path, "target.pickle"),
              "rb") as handle:
        target_img_tuple = pickle.load(handle)
    target_class = target_img_tuple[1]
    # A 4-tuple target means a backdoor attack:
    # (image, label, trigger patch, patch location).
    if len(target_img_tuple) == 4:
        patch = (target_img_tuple[2] if torch.is_tensor(target_img_tuple[2])
                 else torch.tensor(target_img_tuple[2]))
        if patch.shape[0] != 3 or patch.shape[1] != args.patch_size or \
                patch.shape[2] != args.patch_size:
            print(
                f"Expected shape of the patch is [3, {args.patch_size}, {args.patch_size}] "
                f"but is {patch.shape}. Exiting from poison_test.py.")
            sys.exit()

        startx, starty = target_img_tuple[3]
        target_img_pil = target_img_tuple[0]
        # FIX: PIL's Image.size is (width, height); the original unpacked it
        # as (h, w), swapping the bounds check for non-square images.
        w, h = target_img_pil.size

        if starty + args.patch_size > h or startx + args.patch_size > w:
            print(
                "Invalid startx or starty point for the patch. Exiting from poison_test.py."
            )
            sys.exit()

        # paste the trigger patch onto the target image
        target_img_tensor = transforms.ToTensor()(target_img_pil)
        target_img_tensor[:, starty:starty + args.patch_size,
                          startx:startx + args.patch_size] = patch
        target_img_pil = transforms.ToPILImage()(target_img_tensor)
    else:
        target_img_pil = target_img_tuple[0]
    target_img = transform_test(target_img_pil)

    poison_perturbation_norms = compute_perturbation_norms(
        poison_tuples, dataset, poison_indices)

    # the limit is '8/255' but we assert that it is smaller than 9/255 to
    # account for PIL truncation.
    assert max(
        poison_perturbation_norms) - 9 / 255 < 1e-5, "Attack not clean label!"
    ####################################################

    ####################################################
    #           Network and Optimizer

    # load model from path if a path is provided
    if args.model_path is not None:
        net = load_model_from_checkpoint(args.model, args.model_path,
                                         args.pretrain_dataset)
    else:
        args.ffe = False  # we wouldn't fine tune from a random initialization
        net = get_model(args.model, args.dataset)

    # freeze weights in feature extractor if not doing from-scratch retraining
    if args.ffe:
        for param in net.parameters():
            param.requires_grad = False

    # reinitialize the linear layer
    num_ftrs = net.linear.in_features
    net.linear = nn.Linear(num_ftrs, num_classes)  # requires grad by default

    # set optimizer
    if args.optimizer.upper() == "SGD":
        optimizer = optim.SGD(net.parameters(), lr=lr,
                              weight_decay=args.weight_decay, momentum=0.9)
    elif args.optimizer.upper() == "ADAM":
        optimizer = optim.Adam(net.parameters(), lr=lr,
                               weight_decay=args.weight_decay)
    else:
        # FIX: fail fast on an unknown optimizer instead of hitting an
        # UnboundLocalError at the first optimizer use below.
        raise ValueError(f"Unsupported optimizer: {args.optimizer}")

    criterion = nn.CrossEntropyLoss()
    ####################################################

    ####################################################
    #        Poison and Train and Test
    print("==> Training network...")
    epoch = 0
    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_schedule,
                             args.lr_factor)
        loss, acc = train(net, trainloader, optimizer, criterion, device,
                          train_bn=not args.ffe)

        if (epoch + 1) % args.val_period == 0:
            natural_acc = test(net, testloader, device)
            net.eval()
            p_acc = (net(target_img.unsqueeze(0).to(device)).max(1)[1].item()
                     == poisoned_label)
            print(
                now(),
                " Epoch: ", epoch,
                ", Loss: ", loss,
                ", Training acc: ", acc,
                ", Natural accuracy: ", natural_acc,
                ", poison success: ", p_acc,
            )

    # test
    natural_acc = test(net, testloader, device)
    print(now(), " Training ended at epoch ", epoch,
          ", Natural accuracy: ", natural_acc)
    net.eval()
    p_acc = net(
        target_img.unsqueeze(0).to(device)).max(1)[1].item() == poisoned_label
    print(
        now(),
        " poison success: ", p_acc,
        " poisoned_label: ", poisoned_label,
        " prediction: ",
        net(target_img.unsqueeze(0).to(device)).max(1)[1].item(),
    )

    # Dictionary to write contents to the csv file
    stats = OrderedDict([
        ("poisons path", args.poisons_path),
        ("model",
         args.model_path if args.model_path is not None else args.model),
        ("target class", target_class),
        ("base class", poisoned_label),
        ("num poisons", len(poison_tuples)),
        ("max perturbation norm", np.max(poison_perturbation_norms)),
        ("epoch", epoch),
        ("loss", loss),
        ("training_acc", acc),
        ("natural_acc", natural_acc),
        ("poison_acc", p_acc),
    ])
    to_results_table(stats, args.output)
    ####################################################
    return
def main(args):
    """Main function to generate the FC (feature collision) poisons.

    inputs:
        args: Argparse object
    return:
        void
    """
    print(now(), "craft_poisons_fc.py main() running.")
    craft_log = "craft_log.txt"
    to_log_file(args, args.output, craft_log)

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ####################################################
    #               Dataset
    # All tinyimagenet_* variants differ only in the `classes` argument.
    tinyimagenet_classes = {
        "tinyimagenet_first": "firsthalf",
        "tinyimagenet_last": "lasthalf",
        "tinyimagenet_all": "all",
    }
    if args.dataset.lower() == "cifar10":
        transform_test = get_transform(args.normalize, False)
        trainset = torchvision.datasets.CIFAR10(root="./data", train=True,
                                                download=True,
                                                transform=transform_test)
        testset = torchvision.datasets.CIFAR10(root="./data", train=False,
                                               download=True,
                                               transform=transform_test)
    elif args.dataset.lower() in tinyimagenet_classes:
        transform_test = get_transform(args.normalize, False,
                                       dataset=args.dataset)
        classes = tinyimagenet_classes[args.dataset.lower()]
        trainset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="train",
            transform=transform_test,
            classes=classes,
        )
        testset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="val",
            transform=transform_test,
            classes=classes,
        )
    else:
        print("Dataset not yet implemented. Exiting from craft_poisons_fc.py.")
        sys.exit()
    ###################################################

    ####################################################
    #       Craft and insert poison image
    feature_extractors = []
    for i in range(len(args.model)):
        feature_extractors.append(
            load_model_from_checkpoint(args.model[i], args.model_path[i],
                                       args.pretrain_dataset))

    # freeze and prepare each feature extractor for inference
    for i in range(len(feature_extractors)):
        for param in feature_extractors[i].parameters():
            param.requires_grad = False
        feature_extractors[i].eval()
        feature_extractors[i] = feature_extractors[i].to(device)

    with open(args.poison_setups, "rb") as handle:
        setup_dicts = pickle.load(handle)
    setup = setup_dicts[args.setup_idx]

    # Command-line overrides take precedence over the pickled setup.
    target_img_idx = (setup["target index"] if args.target_img_idx is None
                      else args.target_img_idx)
    base_indices = (setup["base indices"] if args.base_indices is None
                    else args.base_indices)

    # Craft poisons
    poison_iterations = args.crafting_iters
    poison_perturbation_norms = []

    # get single target
    target_img, target_label = testset[target_img_idx]

    # get multiple bases
    base_imgs = torch.stack([trainset[i][0] for i in base_indices])
    base_labels = torch.LongTensor([trainset[i][1] for i in base_indices])

    # log target and base details
    to_log_file("base indices: " + str(base_indices), args.output, craft_log)
    to_log_file("base labels: " + str(base_labels), args.output, craft_log)
    to_log_file("target_label: " + str(target_label), args.output, craft_log)
    to_log_file("target_index: " + str(target_img_idx), args.output, craft_log)

    # fill list of tuples of poison images and labels
    poison_tuples = []
    target_img = (un_normalize_data(target_img, args.dataset)
                  if args.normalize else target_img)
    # beta weights the l2 proximity term relative to the feature loss
    beta = 4.0 if args.normalize else 0.1
    base_tuples = list(zip(base_imgs, base_labels))
    for base_img, label in base_tuples:
        # unnormalize the images for optimization
        b_unnormalized = (un_normalize_data(base_img, args.dataset)
                          if args.normalize else base_img)
        objective_vals = [10e8]  # sentinel: any real objective is smaller
        step_size = args.step_size

        # watermarking: blend the target into the base as a starting point
        x = copy.deepcopy(b_unnormalized)
        x = args.watermark_coeff * target_img + (1 - args.watermark_coeff) * x

        # feature collision optimization
        done_with_fc = False
        i = 0
        while not done_with_fc and i < poison_iterations:
            x.requires_grad = True
            if args.normalize:
                mini_batch = torch.stack([
                    normalize_data(x, args.dataset),
                    normalize_data(target_img, args.dataset),
                ]).to(device)
            else:
                mini_batch = torch.stack([x, target_img]).to(device)
            # feature loss: squared distance between the poison's and the
            # target's penultimate-layer features, summed over extractors
            loss = 0
            for feature_extractor in feature_extractors:
                feats = feature_extractor.penultimate(mini_batch)
                loss += torch.norm(feats[0, :] - feats[1, :]) ** 2
            grad = torch.autograd.grad(loss, [x])[0]

            # forward (gradient) step
            x_hat = x.detach() - step_size * grad.detach()
            if not args.l2:
                # linf constraint: project onto the epsilon ball around the
                # base image and the valid pixel range
                pert = (x_hat - b_unnormalized).clamp(-args.epsilon,
                                                      args.epsilon)
                x_new = b_unnormalized + pert
                x_new = x_new.clamp(0, 1)
                obj = loss
            else:
                # l2 penalty: proximal (backward) step toward the base image
                x_new = (x_hat.detach()
                         + step_size * beta * b_unnormalized.detach()) \
                    / (1 + step_size * beta)
                x_new = x_new.clamp(0, 1)
                obj = beta * torch.norm(x_new - b_unnormalized) ** 2 + loss

            # backtracking: shrink the step when the objective increased,
            # otherwise accept the step and check for convergence.
            # NOTE(review): loop structure reconstructed from collapsed
            # source — confirm i increments on rejected steps too.
            if obj > objective_vals[-1]:
                step_size *= 0.2
            else:
                if torch.norm(x - x_new) / torch.norm(x) < 1e-5:
                    done_with_fc = True
                x = copy.deepcopy(x_new)
                objective_vals.append(obj)
            i += 1

        poison_tuples.append((transforms.ToPILImage()(x), label.item()))
        poison_perturbation_norms.append(
            torch.max(torch.abs(x - b_unnormalized)).item())
        x.requires_grad = False
    ####################################################

    ####################################################
    #        Save Poisons
    print(now(), "Saving poisons...")
    if not os.path.isdir(args.poisons_path):
        os.makedirs(args.poisons_path)
    with open(os.path.join(args.poisons_path, "poisons.pickle"),
              "wb") as handle:
        pickle.dump(poison_tuples, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.poisons_path, "perturbation_norms.pickle"),
              "wb") as handle:
        pickle.dump(poison_perturbation_norms, handle,
                    protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.poisons_path, "base_indices.pickle"),
              "wb") as handle:
        pickle.dump(base_indices, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open(os.path.join(args.poisons_path, "target.pickle"),
              "wb") as handle:
        pickle.dump(
            (transforms.ToPILImage()(target_img), target_label),
            handle,
            protocol=pickle.HIGHEST_PROTOCOL,
        )
    to_log_file("poisons saved.", args.output, craft_log)
    ####################################################
    print(now(), "craft_poisons_fc.py done.")
    return