tt = 0
    else:
        print(args)
        print("Path: {}".format(chk_path))
        # Otherwise, we craft new poisons
        if args.nearest:
            base_tensor_list, base_idx_list = fetch_nearest_poison_bases(
                sub_net_list, target, args.poison_num, args.poison_label,
                args.num_per_class, 'others', args.train_data_path,
                transform_test)

        else:
            # just fetch the first poison_num samples
            base_tensor_list, base_idx_list = fetch_poison_bases(
                args.poison_label,
                args.poison_num,
                subset='others',
                path=args.train_data_path,
                transforms=transform_test)
        base_tensor_list = [bt.to('cuda') for bt in base_tensor_list]
        print("Selected base image indices: {}".format(base_idx_list))

        if args.resume_poison_ite > 0:
            state_dict = torch.load(
                os.path.join(chk_path,
                             "poison_%05d.pth" % args.resume_poison_ite))
            poison_tuple_list, base_idx_list = state_dict[
                'poison'], state_dict['idx']
            poison_init = [pt.to('cuda') for pt, _ in poison_tuple_list]
            # re-direct the results to the resumed dir...
            chk_path += '-resume'
            if not os.path.exists(chk_path):
def get_poison_perturbation(poisons_root_path,
                            ites,
                            target_ids,
                            poison_label,
                            poison_num=5,
                            train_data_path='datasets/CIFAR10_TRAIN_Split.pth',
                            device='cpu'):
    """Measure how far saved poisons are from their clean base images.

    For every iteration in ``ites`` and every target id in ``target_ids`` this
    loads ``{poisons_root_path}/{t_id}/poison_{int(ite) - 1:05d}.pth``. It
    undoes the CIFAR ``Normalize`` transform on both the poisons and the
    clean bases returned by ``fetch_poison_bases``. It then measures, per
    poison, the L-inf and L2 distance between poison and base.

    Args:
        poisons_root_path: directory holding one sub-directory per target id.
        ites: iteration numbers (anything ``int()`` accepts). For ``ite`` the
            checkpoint ``poison_%05d.pth`` with index ``ite - 1`` is loaded.
        target_ids: target identifiers, used as sub-directory names.
        poison_label: class label forwarded to ``fetch_poison_bases``.
        poison_num: number of poisons / bases per checkpoint.
        train_data_path: dataset path forwarded to ``fetch_poison_bases``.
        device: torch device string (e.g. ``'cpu'``, ``'cuda'``,
            ``'cuda:1'``). Checkpoints are loaded onto it and all distance
            computations run on it.

    Returns:
        ``(linf, l2)``: two lists with one entry per element of ``ites``.
        Each entry is a list of ``poison_num`` floats. For each target, the
        per-poison distances are sorted ascending; entry ``k`` is the mean
        over targets of the ``k``-th smallest distance.

    Raises:
        FileNotFoundError: if a poison checkpoint does not exist.
        AssertionError: if a checkpoint's base indices differ from the freshly
            fetched ones, or if any poison's L-inf perturbation exceeds the
            module-level ``EPSILON`` (plus 1e-5 tolerance).
    """
    import os

    import torch
    import torchvision.transforms as transforms

    torch_device = torch.device(device)

    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])
    # Broadcastable (1, 3, 1, 1) tensors used to undo Normalize(). They must
    # live on the same device as the poisons/bases they are combined with.
    cifar_mean = torch.Tensor(cifar_mean).reshape((1, 3, 1, 1)).to(torch_device)
    cifar_std = torch.Tensor(cifar_std).reshape((1, 3, 1, 1)).to(torch_device)

    base_tensor_list, base_idx_list = fetch_poison_bases(
        poison_label,
        poison_num,
        subset='others',
        path=train_data_path,
        transforms=transform_test)
    base_tensor_batch = torch.stack(base_tensor_list, 0).to(torch_device)
    # Flatten each image so per-poison norms can be taken along dim=1.
    base_range01_batch = (base_tensor_batch * cifar_std + cifar_mean).view(
        (poison_num, -1))

    linf = []
    l2 = []
    for ite in ites:
        linf_tmp = []
        l2_tmp = []
        for t_id in target_ids:
            path = '{}/{}/poison_{}.pth'.format(poisons_root_path, t_id,
                                                "%.5d" % (int(ite) - 1))
            if not os.path.exists(path):
                raise FileNotFoundError(
                    "Poison checkpoint not found: {}".format(path))
            # map_location handles every device string and places the poisons
            # next to the bases and normalization constants.
            state_dict = torch.load(path, map_location=torch_device)
            poison_tuple_list, idx_list = state_dict['poison'], state_dict[
                'idx']
            poison_batch = torch.stack([p for p, _ in poison_tuple_list], 0)
            poison_range01_batch = (poison_batch * cifar_std +
                                    cifar_mean).view((poison_num, -1))

            # The checkpoint must have been crafted from the same base images.
            for i, ii in zip(idx_list, base_idx_list):
                assert i == ii
            diff = poison_range01_batch - base_range01_batch
            abs_diff = torch.abs(diff)
            linf_diff = torch.max(abs_diff, dim=1).values
            max_perturb = torch.max(linf_diff).item()
            assert max_perturb <= EPSILON + 1e-5, "WHAT THE F**K, WE HAVE L-inf perturbation of {}".format(
                max_perturb)

            # Sorting makes the aggregation below rank-wise (k-th smallest
            # distance) rather than tied to a particular poison position.
            linf_tmp.append(sorted(linf_diff.tolist()))
            l2_tmp.append(sorted(torch.norm(diff, dim=1).tolist()))

        # Mean over targets of the k-th smallest distance, for each rank k.
        linf.append([sum(dists[k] for dists in linf_tmp) / len(linf_tmp)
                     for k in range(poison_num)])
        l2.append([sum(dists[k] for dists in l2_tmp) / len(l2_tmp)
                   for k in range(poison_num)])
    return linf, l2