Пример #1
0
def main(args):
    """Main function to generate the CLBD poisons.

    Each base image is perturbed with ``AttackPGD`` against the pretrained
    model, the trigger image is pasted into the bottom-right
    ``args.patch_size`` x ``args.patch_size`` corner, and the total change is
    finally clipped to ``[-args.epsilon, args.epsilon]`` around the base image
    and to the [0, 1] pixel range.

    inputs:
        args:           Argparse object
    return:
        void; writes ``poisons.pickle``, ``target.pickle`` and
        ``base_indices.pickle`` into ``args.poisons_path``.
    """
    print(now(), "craft_poisons_clbd.py main() running...")
    dataset_name = args.dataset.lower()
    mean, std = data_mean_std_dict[dataset_name]
    normalize_net = NormalizeByChannelMeanStd(list(mean), list(std))
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = load_model_from_checkpoint(args.model[0], args.model_path[0],
                                       args.pretrain_dataset)
    model.eval()
    if args.normalize:
        # Images below are loaded un-normalized (get_transform(False, ...)),
        # so normalization is folded into the model for the attack.
        model = nn.Sequential(normalize_net, model)
    model = model.to(device)

    ####################################################
    #               Dataset
    tinyimagenet_classes = {
        "tinyimagenet_first": "firsthalf",
        "tinyimagenet_last": "lasthalf",
        "tinyimagenet_all": "all",
    }
    if dataset_name == "cifar10":
        transform_test = get_transform(False, False)
        testset = torchvision.datasets.CIFAR10(root="./data",
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        trainset = torchvision.datasets.CIFAR10(root="./data",
                                                train=True,
                                                download=True,
                                                transform=transform_test)
    elif dataset_name in tinyimagenet_classes:
        classes = tinyimagenet_classes[dataset_name]
        transform_test = get_transform(False, False, dataset=args.dataset)
        trainset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="train",
            transform=transform_test,
            classes=classes,
        )
        testset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="val",
            transform=transform_test,
            classes=classes,
        )
    else:
        print(
            "Dataset not yet implemented. Exiting from craft_poisons_clbd.py.")
        sys.exit()
    ###################################################

    # NOTE: pickle.load can execute arbitrary code; only use trusted setup files.
    with open(args.poison_setups, "rb") as handle:
        setup_dicts = pickle.load(handle)
    setup = setup_dicts[args.setup_idx]

    target_img_idx = (setup["target index"]
                      if args.target_img_idx is None else args.target_img_idx)
    base_indices = (setup["base indices"]
                    if args.base_indices is None else args.base_indices)

    # get single target
    target_img, target_label = testset[target_img_idx]

    # get multiple bases (each sample is loaded only once)
    base_samples = [trainset[i] for i in base_indices]
    base_imgs = torch.stack([sample[0] for sample in base_samples]).to(device)
    base_labels = torch.LongTensor([sample[1]
                                    for sample in base_samples]).to(device)

    # get attacker
    config = {
        "epsilon": args.epsilon,
        "step_size": args.step_size,
        "num_steps": args.num_steps,
    }
    attacker = AttackPGD(model, config)

    # get patch: trigger has shape (1, C, patch_size, patch_size)
    trans_trigger = transforms.Compose([
        transforms.Resize((args.patch_size, args.patch_size)),
        transforms.ToTensor()
    ])
    trigger = Image.open("./poison_crafting/triggers/clbd.png").convert("RGB")
    trigger = trans_trigger(trigger).unsqueeze(0).to(device)

    # craft poisons: attack the bases in chunks of 1000 images
    # (presumably to bound device memory).
    chunk_size = 1000
    adv_bases = torch.cat([
        attacker(img_chunk, label_chunk)
        for img_chunk, label_chunk in zip(base_imgs.split(chunk_size),
                                          base_labels.split(chunk_size))
    ])

    # Patch location: bottom-right corner of the image.
    start_x = args.image_size - args.patch_size
    start_y = args.image_size - args.patch_size
    rows = slice(start_y, start_y + args.patch_size)
    cols = slice(start_x, start_x + args.patch_size)

    # Drop the adversarial perturbation inside the patch region. Tensors are
    # (N, C, H, W), so the spatial slices belong on dims 2 and 3.
    mask = torch.ones_like(adv_bases)
    mask[:, :, rows, cols] = 0
    adv_bases_masked = base_imgs + (adv_bases - base_imgs) * mask

    # Paste the trigger on every image at once; its leading batch dim of 1
    # broadcasts over N. (Earlier commented-out variants also patched the
    # other three corners, using torch.flip(trigger, (-1,)) on the left side.)
    adv_bases_masked[:, :, rows, cols] = trigger

    # The whole change, trigger included, must stay within the epsilon ball.
    final_pert = torch.clamp(adv_bases_masked - base_imgs, -args.epsilon,
                             args.epsilon)
    poisons = (base_imgs + final_pert).clamp(0, 1).cpu()

    to_pil = transforms.ToPILImage()
    poisoned_tuples = [(to_pil(poison), label.item())
                       for poison, label in zip(poisons, base_labels)]

    target_tuple = (
        to_pil(target_img),
        int(target_label),
        trigger.squeeze(0).cpu(),
        [start_x, start_y],
    )

    ####################################################
    #        Save Poisons
    print(now(), "Saving poisons...")
    os.makedirs(args.poisons_path, exist_ok=True)
    outputs = (
        ("poisons.pickle", poisoned_tuples),
        ("target.pickle", target_tuple),
        ("base_indices.pickle", base_indices),
    )
    for filename, payload in outputs:
        with open(os.path.join(args.poisons_path, filename), "wb") as handle:
            pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)
    ####################################################

    print(now(), "craft_poisons_clbd.py done.")
    return
Пример #2
0
def main(args):
    """Evaluate a (possibly checkpoint-loaded) model on clean data.

    Builds train/test loaders for the requested dataset, measures accuracy on
    both, prints them and records them in ``clean_performance.csv``.

    input:
        args:       Argparse object that contains all the parsed values
    return:
        void
    """

    print(now(), "test_model.py main() running.")

    test_log = "clean_test_log.txt"
    to_log_file(args, args.output, test_log)

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ####################################################
    #               Dataset
    dataset_name = args.dataset.lower()
    cifar_datasets = {
        "cifar10": torchvision.datasets.CIFAR10,
        "cifar100": torchvision.datasets.CIFAR100,
    }
    tiny_class_splits = {
        "tinyimagenet_first": "firsthalf",
        "tinyimagenet_last": "lasthalf",
        "tinyimagenet_all": "all",
    }

    if dataset_name in cifar_datasets:
        cifar_cls = cifar_datasets[dataset_name]
        transform_train = get_transform(args.normalize, args.train_augment)
        transform_test = get_transform(args.normalize, False)
        trainset = cifar_cls(root="./data",
                             train=True,
                             download=True,
                             transform=transform_train)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=128)
        testset = cifar_cls(root="./data",
                            train=False,
                            download=True,
                            transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=128,
                                                 shuffle=False)
    elif dataset_name in tiny_class_splits:
        class_split = tiny_class_splits[dataset_name]
        transform_train = get_transform(args.normalize,
                                        args.train_augment,
                                        dataset=args.dataset)
        transform_test = get_transform(args.normalize,
                                       False,
                                       dataset=args.dataset)
        trainset = TinyImageNet(TINYIMAGENET_ROOT,
                                split="train",
                                transform=transform_train,
                                classes=class_split)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=64,
                                                  num_workers=1,
                                                  shuffle=True)
        testset = TinyImageNet(TINYIMAGENET_ROOT,
                               split="val",
                               transform=transform_test,
                               classes=class_split)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=64,
                                                 num_workers=1,
                                                 shuffle=False)
    else:
        print("Dataset not yet implemented. Exiting from test_model.py.")
        sys.exit()
    ####################################################

    ####################################################
    #           Network
    net = get_model(args.model, args.dataset)
    if args.model_path is None:
        print(
            "No model path provided, continuing test with untrained network.")
    else:
        # A checkpoint replaces the freshly built network entirely.
        net = load_model_from_checkpoint(args.model, args.model_path,
                                         args.dataset)
    net = net.to(device)
    ####################################################

    ####################################################
    #        Test Model
    training_acc = test(net, trainloader, device)
    natural_acc = test(net, testloader, device)
    print(now(), " Training accuracy: ", training_acc)
    print(now(), " Natural accuracy: ", natural_acc)
    results = [
        ("model path", args.model_path),
        ("model", args.model),
        ("normalize", args.normalize),
        ("augment", args.train_augment),
        ("training_acc", training_acc),
        ("natural_acc", natural_acc),
    ]
    to_results_table(OrderedDict(results), args.output, "clean_performance.csv")
    ####################################################

    return
def main(args):
    """Main function to generate the BP poisons.

    Loads the target/base indices from the poison setup file, crafts poisons
    with ``make_convex_polytope_poisons`` (``mode="mean"``) against the models
    listed in ``args.model`` / ``args.model_path``, and saves the poisons,
    their L-inf perturbation norms (measured in [0, 1] pixel space), the base
    indices and the target image into ``args.poisons_path``.

    inputs:
        args:           Argparse object
    return:
        void
    """
    print(now(), "craft_poisons_bp.py main() running.")

    craft_log = "craft_log.txt"
    to_log_file(args, args.output, craft_log)

    # Set device (the original hard-coded "cuda", failing on CPU-only hosts).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ####################################################
    #               Dataset
    dataset_name = args.dataset.lower()
    tinyimagenet_classes = {
        "tinyimagenet_first": "firsthalf",
        "tinyimagenet_last": "lasthalf",
        "tinyimagenet_all": "all",
    }
    if dataset_name == "cifar10":
        transform_test = get_transform(args.normalize, False)
        testset = torchvision.datasets.CIFAR10(
            root="./data", train=False, download=True, transform=transform_test
        )
        trainset = torchvision.datasets.CIFAR10(
            root="./data", train=True, download=True, transform=transform_test
        )
    elif dataset_name in tinyimagenet_classes:
        classes = tinyimagenet_classes[dataset_name]
        transform_test = get_transform(args.normalize, False, dataset=args.dataset)
        trainset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="train",
            transform=transform_test,
            classes=classes,
        )
        testset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="val",
            transform=transform_test,
            classes=classes,
        )
    else:
        print("Dataset not yet implemented. Exiting from craft_poisons_bp.py.")
        sys.exit()
    ###################################################

    ####################################################
    #         Find target and base images
    # NOTE: pickle.load can execute arbitrary code; only use trusted setup files.
    with open(args.poison_setups, "rb") as handle:
        setup_dicts = pickle.load(handle)
    setup = setup_dicts[args.setup_idx]

    target_img_idx = (
        setup["target index"] if args.target_img_idx is None else args.target_img_idx
    )
    base_indices = (
        setup["base indices"] if args.base_indices is None else args.base_indices
    )

    # get single target
    target_img, target_label = testset[target_img_idx]

    # get multiple bases (each sample is loaded only once)
    base_samples = [trainset[i] for i in base_indices]
    base_imgs = torch.stack([sample[0] for sample in base_samples])
    base_labels = [sample[1] for sample in base_samples]
    poisoned_label = base_labels[0]

    # log target and base details
    to_log_file("base indices: " + str(base_indices), args.output, craft_log)
    to_log_file("base labels: " + str(base_labels), args.output, craft_log)
    to_log_file("target_label: " + str(target_label), args.output, craft_log)
    to_log_file("target_index: " + str(target_img_idx), args.output, craft_log)

    # Let cuDNN autotune convolution algorithms for the fixed input sizes.
    cudnn.benchmark = True

    # load the pre-trained models
    sub_net_list = [
        load_model_from_checkpoint(
            args.model[n_model], chk_name, args.pretrain_dataset
        )
        for n_model, chk_name in enumerate(args.model_path)
    ]

    target_net = load_model_from_checkpoint(
        args.target_model, args.target_model_path, args.pretrain_dataset
    )

    # Add a batch dimension to the target image
    target = target_img.unsqueeze(0)

    chk_path = args.poisons_path
    os.makedirs(chk_path, exist_ok=True)

    base_tensor_list = [base_img.to(device) for base_img in base_imgs]

    mean, std = data_mean_std_dict[dataset_name]
    poison_tuple_list = make_convex_polytope_poisons(
        sub_net_list,
        target_net,
        base_tensor_list,
        target,
        device,
        opt_method=args.poison_opt,
        lr=args.poison_lr,
        momentum=args.poison_momentum,
        iterations=args.crafting_iters,
        epsilon=args.epsilon,
        decay_ites=args.poison_decay_ites,
        decay_ratio=args.poison_decay_ratio,
        mean=torch.Tensor(mean).reshape(1, 3, 1, 1),
        std=torch.Tensor(std).reshape(1, 3, 1, 1),
        chk_path=chk_path,
        poison_idxes=base_indices,
        poison_label=poisoned_label,
        tol=args.tol,
        start_ite=0,
        poison_init=base_tensor_list,
        end2end=args.end2end,
        mode="mean",
    )

    def to_pixel_space(tensor):
        """Undo dataset normalization only if the loaded data was normalized."""
        if args.normalize:
            return un_normalize_data(tensor, args.dataset)
        return tensor

    # move poisons to PIL format
    to_pil = transforms.ToPILImage()
    target = to_pixel_space(target.squeeze(0))
    poison_tuple_list = [
        (to_pil(to_pixel_space(poison[0])), poison[1])
        for poison in poison_tuple_list
    ]

    # get perturbation (L-inf) norms; both sides must be in pixel space, so the
    # bases are un-normalized only when args.normalize is set.
    to_tensor = transforms.ToTensor()
    poison_perturbation_norms = [
        torch.max(
            torch.abs(to_tensor(poison[0]) - to_pixel_space(base_tensor.cpu()))
        ).item()
        for poison, base_tensor in zip(poison_tuple_list, base_tensor_list)
    ]
    to_log_file(
        "perturbation norms: " + str(poison_perturbation_norms),
        args.output,
        craft_log,
    )

    ####################################################
    #        Save Poisons
    print(now(), "Saving poisons...")
    os.makedirs(args.poisons_path, exist_ok=True)
    outputs = (
        ("poisons.pickle", poison_tuple_list),
        ("perturbation_norms.pickle", poison_perturbation_norms),
        ("base_indices.pickle", base_indices),
        ("target.pickle", (to_pil(target), target_label)),
    )
    for filename, payload in outputs:
        with open(os.path.join(args.poisons_path, filename), "wb") as handle:
            pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)
    to_log_file("poisons saved.", args.output, craft_log)
    ####################################################

    print(now(), "craft_poisons_bp.py done.")
    return
Пример #4
0
def main(args):
    """Main function to check the success rate of the given poisons.

    Loads ``poisons.pickle``, ``base_indices.pickle`` and ``target.pickle``
    from ``args.poisons_path``, trains ``args.model`` (from scratch, or
    starting from ``args.model_path`` with a re-initialized linear head) on
    the poisoned training set, and reports whether the target image is
    classified as the poisons' label. Results go to ``to_results_table``.

    input:
        args:       Argparse object
    return:
        void
    raises:
        ValueError: if ``args.optimizer`` is neither SGD nor ADAM.
    """
    print(now(), "poison_test.py main() running.")

    test_log = "poison_test_log.txt"
    to_log_file(args, args.output, test_log)

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load the poisons and their indices within the training set from pickled files
    # NOTE: pickle.load can execute arbitrary code; only load trusted poison files.
    with open(os.path.join(args.poisons_path, "poisons.pickle"),
              "rb") as handle:
        poison_tuples = pickle.load(handle)
    print(len(poison_tuples), " poisons in this trial.")
    poisoned_label = poison_tuples[0][1]
    with open(os.path.join(args.poisons_path, "base_indices.pickle"),
              "rb") as handle:
        poison_indices = pickle.load(handle)

    # get the dataset and the dataloaders
    trainloader, testloader, dataset, transform_train, transform_test, num_classes = \
        get_dataset(args, poison_tuples, poison_indices)

    # get the target image from pickled file
    with open(os.path.join(args.poisons_path, "target.pickle"),
              "rb") as handle:
        target_img_tuple = pickle.load(handle)
    target_img_pil = target_img_tuple[0]
    target_class = target_img_tuple[1]

    if len(target_img_tuple) == 4:
        # Targets saved with a trigger are (image, label, patch, [start_x, start_y]);
        # paste the patch onto the target before evaluating it.
        patch = target_img_tuple[2]
        if not torch.is_tensor(patch):
            patch = torch.tensor(patch)
        if tuple(patch.shape) != (3, args.patch_size, args.patch_size):
            print(
                f"Expected shape of the patch is [3, {args.patch_size}, {args.patch_size}] "
                f"but is {patch.shape}. Exiting from poison_test.py.")
            sys.exit()

        startx, starty = target_img_tuple[3]
        # PIL reports Image.size as (width, height), not (height, width).
        w, h = target_img_pil.size
        if starty + args.patch_size > h or startx + args.patch_size > w:
            print(
                "Invalid startx or starty point for the patch. Exiting from poison_test.py."
            )
            sys.exit()

        target_img_tensor = transforms.ToTensor()(target_img_pil)
        target_img_tensor[:, starty:starty + args.patch_size,
                          startx:startx + args.patch_size] = patch
        target_img_pil = transforms.ToPILImage()(target_img_tensor)

    target_img = transform_test(target_img_pil)

    poison_perturbation_norms = compute_perturbation_norms(
        poison_tuples, dataset, poison_indices)

    # the limit is '8/255' but we assert that it is smaller than 9/255 to account for PIL
    # truncation.
    assert max(
        poison_perturbation_norms) - 9 / 255 < 1e-5, "Attack not clean label!"
    ####################################################

    ####################################################
    #           Network and Optimizer

    # load model from path if a path is provided
    if args.model_path is not None:
        net = load_model_from_checkpoint(args.model, args.model_path,
                                         args.pretrain_dataset)
    else:
        args.ffe = False  # we wouldn't fine tune from a random initialization
        net = get_model(args.model, args.dataset)

    # freeze weights in feature extractor if not doing from scratch retraining
    if args.ffe:
        for param in net.parameters():
            param.requires_grad = False

    # reinitialize the linear layer (requires grad by default)
    net.linear = nn.Linear(net.linear.in_features, num_classes)

    # Inputs are sent to `device` below, so the model must live there too;
    # Module.to is a no-op if it already does.
    net = net.to(device)

    # set optimizer
    optimizer_name = args.optimizer.upper()
    if optimizer_name == "SGD":
        optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9)
    elif optimizer_name == "ADAM":
        optimizer = optim.Adam(net.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    else:
        raise ValueError(
            f"Unsupported optimizer {args.optimizer!r}; expected SGD or ADAM.")
    criterion = nn.CrossEntropyLoss()
    ####################################################

    def predict_target():
        """Return the class index the network predicts for the target image."""
        net.eval()
        with torch.no_grad():
            return net(target_img.unsqueeze(0).to(device)).max(1)[1].item()

    ####################################################
    #        Poison and Train and Test
    print("==> Training network...")
    epoch = 0
    loss, acc = None, None  # remain None if args.epochs == 0
    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch, args.lr_schedule,
                             args.lr_factor)
        loss, acc = train(net,
                          trainloader,
                          optimizer,
                          criterion,
                          device,
                          train_bn=not args.ffe)

        if (epoch + 1) % args.val_period == 0:
            natural_acc = test(net, testloader, device)
            p_acc = predict_target() == poisoned_label
            print(
                now(),
                " Epoch: ",
                epoch,
                ", Loss: ",
                loss,
                ", Training acc: ",
                acc,
                ", Natural accuracy: ",
                natural_acc,
                ", poison success: ",
                p_acc,
            )

    # test
    natural_acc = test(net, testloader, device)
    print(now(), " Training ended at epoch ", epoch, ", Natural accuracy: ",
          natural_acc)
    prediction = predict_target()
    p_acc = prediction == poisoned_label

    print(
        now(),
        " poison success: ",
        p_acc,
        " poisoned_label: ",
        poisoned_label,
        " prediction: ",
        prediction,
    )

    # Dictionary to write to the csv file
    stats = OrderedDict([
        ("poisons path", args.poisons_path),
        ("model",
         args.model_path if args.model_path is not None else args.model),
        ("target class", target_class),
        ("base class", poisoned_label),
        ("num poisons", len(poison_tuples)),
        ("max perturbation norm", np.max(poison_perturbation_norms)),
        ("epoch", epoch),
        ("loss", loss),
        ("training_acc", acc),
        ("natural_acc", natural_acc),
        ("poison_acc", p_acc),
    ])
    to_results_table(stats, args.output)
    ####################################################

    return
Пример #5
0
# Lower-cased ``--dataset`` values for TinyImageNet, mapped to the
# ``classes`` argument passed to ``TinyImageNet``.
_FC_TINYIMAGENET_CLASSES = {
    "tinyimagenet_first": "firsthalf",
    "tinyimagenet_last": "lasthalf",
    "tinyimagenet_all": "all",
}


def _load_fc_datasets(args):
    """Build the (trainset, testset) pair selected by ``args.dataset``.

    Both splits use the same ``transform_test`` from ``get_transform``.
    An unsupported dataset name prints a message and exits the process.
    inputs:
        args:           Argparse object
    return:
        (trainset, testset)
    """
    dataset = args.dataset.lower()
    if dataset == "cifar10":
        transform_test = get_transform(args.normalize, False)
        trainset = torchvision.datasets.CIFAR10(root="./data",
                                                train=True,
                                                download=True,
                                                transform=transform_test)
        testset = torchvision.datasets.CIFAR10(root="./data",
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        return trainset, testset

    if dataset in _FC_TINYIMAGENET_CLASSES:
        classes = _FC_TINYIMAGENET_CLASSES[dataset]
        # forward the dataset name exactly as given on the command line
        transform_test = get_transform(args.normalize,
                                       False,
                                       dataset=args.dataset)
        trainset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="train",
            transform=transform_test,
            classes=classes,
        )
        testset = TinyImageNet(
            TINYIMAGENET_ROOT,
            split="val",
            transform=transform_test,
            classes=classes,
        )
        return trainset, testset

    print("Dataset not yet implemented. Exiting from craft_poisons_fc.py.")
    sys.exit()


def _load_frozen_feature_extractors(args, device):
    """Load every model in ``args.model`` as a frozen, eval-mode extractor.

    ``args.model[i]`` is loaded from ``args.model_path[i]`` through
    ``load_model_from_checkpoint``. Its parameters get
    ``requires_grad=False``, so during crafting only the poison image
    receives gradients.
    inputs:
        args:           Argparse object
        device:         device each model is moved to
    return:
        list of models, in the same order as ``args.model``
    """
    extractors = []
    for i, arch in enumerate(args.model):
        # index (rather than zip) so a too-short model_path list still errors
        net = load_model_from_checkpoint(arch, args.model_path[i],
                                         args.pretrain_dataset)
        for param in net.parameters():
            param.requires_grad = False
        net.eval()
        extractors.append(net.to(device))
    return extractors


def _craft_fc_poison(base_img, target_img, feature_extractors, args, device,
                     beta):
    """Craft a single feature-collision poison from ``base_img``.

    The base is watermarked with the target (weight
    ``args.watermark_coeff``). It is then updated iteratively so that the
    penultimate-layer features of the poison approach those of the target,
    summed over all extractors.

    After each gradient step, the image is pulled back towards the base:
    - by default, it is clipped into an ``args.epsilon`` l-inf box;
    - when ``args.l2`` is set, it is blended with the base using weight
      ``beta``.
    The result is then clamped to [0, 1].

    A step whose objective exceeds the last accepted one is rejected, and
    the step size is multiplied by 0.2. The loop stops after
    ``args.crafting_iters`` iterations, or once an accepted step changes
    the image by less than 1e-5 in relative l2 norm.
    inputs:
        base_img:           base image as returned by the dataset
        target_img:         target image, already un-normalized if
                            args.normalize is set
        feature_extractors: frozen models exposing ``penultimate``
        args:               Argparse object
        device:             device the feature extractors live on
        beta:               weight of the l2 proximity term
    return:
        (poison, base): both un-normalized and detached from autograd
    """
    # unnormalize the images for optimization
    base = (un_normalize_data(base_img, args.dataset)
            if args.normalize else base_img)
    # the target's network input does not change across iterations
    target_in = (normalize_data(target_img, args.dataset)
                 if args.normalize else target_img)

    # watermarking; this allocates a new tensor, so ``base`` is never
    # modified by the optimization below
    x = args.watermark_coeff * target_img + (1 - args.watermark_coeff) * base

    step_size = args.step_size
    # objective of the last accepted step (10e8 is the initial sentinel),
    # kept as a Python float so no autograd graph survives across iterations
    prev_obj = 10e8
    num_iters = 0
    while num_iters < args.crafting_iters:
        num_iters += 1
        x.requires_grad = True
        x_in = normalize_data(x, args.dataset) if args.normalize else x
        mini_batch = torch.stack([x_in, target_in]).to(device)

        # squared feature distance between poison (row 0) and target (row 1)
        loss = 0
        for feature_extractor in feature_extractors:
            feats = feature_extractor.penultimate(mini_batch)
            loss += torch.norm(feats[0, :] - feats[1, :])**2
        grad = torch.autograd.grad(loss, [x])[0]
        # drop the graph; everything below is plain tensor arithmetic
        x = x.detach()
        x_hat = x - step_size * grad.detach()

        if not args.l2:
            pert = (x_hat - base).clamp(-args.epsilon, args.epsilon)
            x_new = (base + pert).clamp(0, 1)
            obj = loss.item()
        else:
            x_new = (x_hat + step_size * beta * base.detach()) / (
                1 + step_size * beta)
            x_new = x_new.clamp(0, 1)
            obj = (beta * torch.norm(x_new - base)**2 + loss).item()

        if obj > prev_obj:
            # reject the step; retry from the same x with a smaller step
            step_size *= 0.2
            continue

        converged = torch.norm(x - x_new) / torch.norm(x) < 1e-5
        x = x_new
        prev_obj = obj
        if converged:
            break

    return x.detach(), base


def main(args):
    """Main function to generate the FC (feature collision) poisons
    inputs:
        args:           Argparse object
    return:
        void
    """
    print(now(), "craft_poisons_fc.py main() running.")

    craft_log = "craft_log.txt"
    to_log_file(args, args.output, craft_log)

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ####################################################
    #               Dataset
    trainset, testset = _load_fc_datasets(args)
    ###################################################

    ####################################################
    #          Craft and insert poison image
    feature_extractors = _load_frozen_feature_extractors(args, device)

    # NOTE(review): pickle.load can execute arbitrary code; only point
    # args.poison_setups at trusted files.
    with open(args.poison_setups, "rb") as handle:
        setup_dicts = pickle.load(handle)
    setup = setup_dicts[args.setup_idx]

    # explicit command-line values take precedence over the stored setup
    target_img_idx = (setup["target index"]
                      if args.target_img_idx is None else args.target_img_idx)
    base_indices = (setup["base indices"]
                    if args.base_indices is None else args.base_indices)

    # get single target
    target_img, target_label = testset[target_img_idx]

    # get multiple bases
    base_imgs = torch.stack([trainset[i][0] for i in base_indices])
    base_labels = torch.LongTensor([trainset[i][1] for i in base_indices])

    # log target and base details
    to_log_file("base indices: " + str(base_indices), args.output, craft_log)
    to_log_file("base labels: " + str(base_labels), args.output, craft_log)
    to_log_file("target_label: " + str(target_label), args.output, craft_log)
    to_log_file("target_index: " + str(target_img_idx), args.output, craft_log)

    # crafting happens in un-normalized pixel space
    target_img = (un_normalize_data(target_img, args.dataset)
                  if args.normalize else target_img)
    beta = 4.0 if args.normalize else 0.1

    # fill list of tuples of poison images and labels
    poison_tuples = []
    poison_perturbation_norms = []
    for base_img, label in zip(base_imgs, base_labels):
        poison, base = _craft_fc_poison(base_img, target_img,
                                        feature_extractors, args, device,
                                        beta)
        poison_tuples.append((transforms.ToPILImage()(poison), label.item()))
        # l-inf distance between poison and base in pixel space
        poison_perturbation_norms.append(
            torch.max(torch.abs(poison - base)).item())
    ####################################################

    ####################################################
    #        Save Poisons
    print(now(), "Saving poisons...")
    os.makedirs(args.poisons_path, exist_ok=True)
    outputs = (
        ("poisons.pickle", poison_tuples),
        ("perturbation_norms.pickle", poison_perturbation_norms),
        ("base_indices.pickle", base_indices),
        ("target.pickle", (transforms.ToPILImage()(target_img), target_label)),
    )
    for filename, payload in outputs:
        with open(os.path.join(args.poisons_path, filename), "wb") as handle:
            pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)
    to_log_file("poisons saved.", args.output, craft_log)
    ####################################################

    print(now(), "craft_poisons_fc.py done.")
    return