# Ground-truth sample for the DLG attack: CIFAR-100 image ``args.index`` by
# default, or the image file at ``args.image`` when one is given.
dst = datasets.CIFAR100("~/.torch", download=True)
tp = transforms.ToTensor()
tt = transforms.ToPILImage()

img_index = args.index
if len(args.image) > 1:
    # NOTE(review): the ``> 1`` test presumably treats an empty/placeholder
    # ``args.image`` as "not given" -- confirm against the argument parser.
    # Convert to RGB so grayscale / RGBA / palette files still yield a
    # 3-channel tensor like the RGB CIFAR samples, and let the context
    # manager release the file handle.
    with Image.open(args.image) as custom_image:
        gt_data = tp(custom_image.convert("RGB")).to(device)
else:
    gt_data = tp(dst[img_index][0]).to(device)

# Add a leading batch dimension of size 1.
gt_data = gt_data.unsqueeze(0)
# NOTE(review): the label is always taken from the dataset sample at
# ``img_index``, even when a custom image replaced the pixels above.
gt_label = torch.tensor([dst[img_index][1]], dtype=torch.long, device=device)
gt_onehot_label = label_to_onehot(gt_label)

plt.imshow(tt(gt_data[0].cpu()))

from models.vision import LeNet, weights_init
net = LeNet().to(device)

net.apply(weights_init)
criterion = cross_entropy_for_onehot

# compute original gradient
pred = net(gt_data)
y = criterion(pred, gt_onehot_label)
dy_dx = torch.autograd.grad(y, net.parameters())

# Detached per-parameter copies of the gradient -- the "shared" gradient the
# attack later tries to match.
original_dy_dx = [g.detach().clone() for g in dy_dx]
def main(args):
    """Run one DLG (Deep Leakage from Gradients) experiment on a CIFAR-100 sample.

    A LeNet is built and initialised. The gradient of the loss on the
    ground-truth sample ``args.image_idx`` is computed (the "shared"
    gradient), and ``dlg_method`` reconstructs a dummy image/label from it.
    The model parameters, convergence plots and per-snapshot reconstruction
    distances are written under ``args.exp_dir``.

    Attributes read from ``args``:
        visible_gpus: value written to ``CUDA_VISIBLE_DEVICES``.
        seed: passed to ``torch.manual_seed``.
        data_dir: CIFAR-100 root/download directory.
        image_idx: index of the attacked dataset sample.
        max_iters: forwarded to ``dlg_method``.
        exp_dir: output directory; created if it does not exist.
        mse_tol: threshold on the final reconstruction distance used to
            report convergence.
    """
    # Set GPUs and device
    os.environ['CUDA_VISIBLE_DEVICES'] = args.visible_gpus
    # ``is_available`` must be *called*: the bare function object is always
    # truthy, which forced 'cuda' even on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('Running on %s' % device)

    # Set environment
    torch.manual_seed(args.seed)

    # Get dataset and define transformations
    dst = datasets.CIFAR100(args.data_dir, download=True)
    tp = transforms.Compose([
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor()
    ])
    tt = transforms.ToPILImage()

    # Construct model and initialize weights
    net = models.LeNet().to(device)
    net.apply(utils.weights_init)

    # Define criterion
    criterion = utils.cross_entropy_for_onehot

    # Get attack data and label as a batch of one
    gt_data = tp(dst[args.image_idx][0]).to(device)
    gt_data = gt_data.view(1, *gt_data.size())
    gt_image = tt(gt_data[0].cpu())
    gt_label = torch.Tensor([dst[args.image_idx][1]]).long().to(device)
    gt_label = gt_label.view(1, )
    gt_onehot_label = utils.label_to_onehot(gt_label, num_classes=100)

    # Compute original gradient
    out = net(gt_data)
    y = criterion(out, gt_onehot_label)
    dy_dx = torch.autograd.grad(y, net.parameters())

    # Share the gradients with other clients (detached per-parameter copies)
    original_dy_dx = [g.detach().clone() for g in dy_dx]

    # Generate dummy data and label
    dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
    dummy_label = torch.randn(
        gt_onehot_label.size()).to(device).requires_grad_(True)

    dummy_init_image = tt(dummy_data[0].cpu())
    dummy_init_label = torch.argmax(dummy_label, dim=-1)

    # Define optimizer
    optimizer = torch.optim.LBFGS([dummy_data, dummy_label])

    # Run DLG method
    dummy_grads, _dummy_lbfgs_num_iter, history = \
        dlg_method(dummy_data, dummy_label, original_dy_dx,
                   net, criterion, optimizer, tt, max_iters=args.max_iters)

    # Save model; create the output directory first so torch.save cannot
    # fail on a fresh exp_dir.
    os.makedirs(args.exp_dir, exist_ok=True)
    params_path = os.path.join(args.exp_dir,
                               '%04d_params.pkl' % args.image_idx)
    torch.save(net.state_dict(), params_path)
    print('Save model parameters to %s' % params_path)

    # Per-parameter gradient statistics: (file suffix, function, plot title).
    stat_specs = [
        ('l2norm', lambda x: (x**2).sum().item()**0.5, 'L2 Norm Convergency'),
        ('min', lambda x: x.min().item(), 'Min Value Convergency'),
        ('max', lambda x: x.max().item(), 'Max Value Convergency'),
        ('mean', lambda x: x.mean().item(), 'Mean Convergency'),
        ('median', lambda x: x.median().item(), 'Median Value Convergency'),
    ]
    # One list per statistic for the shared gradients; one 2-D array per
    # statistic (row per entry of ``dummy_grads``, column per gradient
    # tensor) for the dummy gradients returned by dlg_method.
    original_stats = {
        key: [fn(g) for g in original_dy_dx] for key, fn, _ in stat_specs}
    dummy_stats = {
        key: np.array([[fn(g) for g in row] for row in dummy_grads])
        for key, fn, _ in stat_specs}

    # Plot and save figures
    fig_dir = os.path.join(args.exp_dir, 'figures')
    os.makedirs(fig_dir, exist_ok=True)

    img_history = [[tt(hd), hl] for hd, hl in history]
    utils.plot_history([gt_image, gt_label],
                       [dummy_init_image, dummy_init_label],
                       img_history,
                       title='Image %04d History' % args.image_idx,
                       fig_path=os.path.join(
                           fig_dir, '%04d_history.png' % args.image_idx))

    for key, _, title in stat_specs:
        utils.plot_convergency_curve(
            dummy_stats[key],
            original_stats[key],
            title='Image %04d %s' % (args.image_idx, title),
            fig_path=os.path.join(fig_dir,
                                  '%04d_%s.png' % (args.image_idx, key)))

    def compute_mse(x, y):
        # NOTE: despite the name (kept for the ``mse_tol`` option), this is
        # the Euclidean distance -- sqrt of the *summed* squared error.
        return ((x - y)**2).sum().item()**0.5

    final_mse = compute_mse(dummy_data, gt_data)
    converged = final_mse < args.mse_tol
    status = 'Converged' if converged else 'Diverged'
    print('%s!! (MSE=%2.6f)' % (status, final_mse))

    # Save the distance of every history snapshot to the ground truth. Move
    # snapshots to the run's device rather than hard-coding .cuda(), which
    # crashed CPU-only runs.
    mses = np.array([compute_mse(hd.to(device), gt_data) for hd, _ in history])
    with open(os.path.join(args.exp_dir, '%04d_mses.npy' % args.image_idx),
              'wb') as opf:
        np.save(opf, mses)
Esempio n. 3
0
    [transforms.Resize(32),
     transforms.CenterCrop(32),
     transforms.ToTensor()])
tt = transforms.ToPILImage()

# Ground-truth sample that the deep-leakage attack tries to reconstruct.
img_index = args.index
gt_data = tp(dst[img_index][0]).to(device)
# Data range is from 0 to 1.

# Add a leading batch dimension of size 1.
gt_data = gt_data.view(1, *gt_data.size())
print(gt_data.shape)
gt_label = torch.tensor([dst[img_index][1]], dtype=torch.long, device=device)
print("Ground Truth Label is %d" % gt_label.item())
# Size the one-hot target with ``nclass`` -- the same class count passed to
# the model constructors below (defined earlier in the script) -- instead of
# a hard-coded 10, so target and (presumably) model output width agree for
# datasets that do not have exactly 10 classes.
gt_onehot_label = label_to_onehot(gt_label, num_classes=nclass)

plt.imshow(tt(gt_data[0].cpu()))
from models.vision import LeNet, weights_init, ResNet18, ResNet34, AlexNet
if args.arch == 'LeNet5':
    net = LeNet(nchannel, nclass, args.act).to(device)
    net.apply(weights_init)
elif args.arch == 'ResNet18':
    net = ResNet18(nclass, nchannel, args.act).to(device)
    net.apply(weights_init)
elif args.arch == 'ResNet34':
    net = ResNet34(nclass, nchannel, args.act).to(device)
    net.apply(weights_init)
elif args.arch == 'AlexNet':
    net = AlexNet(nclass=nclass, in_channels=nchannel, act=args.act).to(device)
else:
Esempio n. 4
0
# Train and test splits share the same transform; each loader yields one
# sample per batch, in dataset order, using two worker processes.
trainset = dataset(transform=transform, test=False)
testset = dataset(transform=transform, test=True)
trainloader, testloader = (
    torch.utils.data.DataLoader(split,
                                batch_size=1,
                                shuffle=False,
                                num_workers=2)
    for split in (trainset, testset)
)

if args.train:
    for i, (image, label) in enumerate(trainloader):
        gt_data = image.to(device)
        gt_onehot_label = label_to_onehot(label.to(device))

        # compute original gradient
        pred = net(gt_data)
        y = criterion(pred, gt_onehot_label)
        dy_dx = torch.autograd.grad(y, net.parameters())
        original_dy_dx = list((_.detach().clone() for _ in dy_dx))
        best_loss = 1e5
        while best_loss > 5:
            # generate dummy data and label
            dummy_data = torch.randn(
                gt_data.size()).to(device).requires_grad_(True)
            dummy_label = torch.randn(
                gt_onehot_label.size()).to(device).requires_grad_(True)

            optimizer = torch.optim.LBFGS([dummy_data, dummy_label])
Esempio n. 5
0
def DoesRestoreDummyImageFromGivenOriginalGradientSuccessful(
        net, weights_init, dst, indexNo, addNoiseOrNot="N", setVariation=0.1,
        showImageOnScreen="N", lossThresholdValueToJudgeInCorrect=1.00,
        judge_by_category_classification="Y"):
    """Run a DLG gradient-matching reconstruction of ``dst[indexNo]`` and score it.

    ``net`` is re-initialised with ``weights_init`` after seeding torch with
    1234. The gradient of the loss on the ground-truth sample is then
    computed. A random dummy image/label is optimised with L-BFGS for 300
    steps so that its gradient matches that "shared" gradient. Every 10th
    step is logged and snapshotted, and the 30 snapshots are plotted.

    Args:
        net: model to attack (re-initialised in place).
        weights_init: function applied to ``net`` via ``net.apply``.
        dst: indexable dataset; ``dst[i]`` is (image, label).
        indexNo: index of the sample to reconstruct.
        addNoiseOrNot: "Y"/"y" calls ``addGausianNoise(net, [setVariation])``
            before the shared gradient is taken.
        setVariation: value passed to ``addGausianNoise`` (cast to float).
        showImageOnScreen: "Y"/"y" calls ``plt.show()`` before the figures
            are closed.
        lossThresholdValueToJudgeInCorrect: gradient-matching loss threshold.
        judge_by_category_classification: when "N"/"n" and the last logged
            loss exceeds the threshold, the result is forced to 0.

    Returns:
        int: ``int(result)``, where ``result`` is the value of
        ``AreDummyAndOriginalLabelNoMatch(gt_onehot, dummy_onehot)`` (its
        meaning is defined elsewhere -- confirm) or 0 when the loss check
        above applies.

    Relies on module globals ``tp``, ``tt``, ``device``, ``label_to_onehot``,
    ``cross_entropy_for_onehot`` and ``F`` defined elsewhere in the file.
    """
    gt_data = tp(dst[indexNo][0]).to(device)
    gt_data = gt_data.view(1, *gt_data.size())
    gt_label = torch.Tensor([dst[indexNo][1]]).long().to(device)
    gt_label = gt_label.view(1, )
    gt_onehot_label = label_to_onehot(gt_label)
    plt.imshow(tt(gt_data[0].cpu()))

    torch.manual_seed(1234)

    net.apply(weights_init)
    criterion = cross_entropy_for_onehot

    # compute original gradient
    pred = net(gt_data)
    y = criterion(pred, gt_onehot_label)
    if addNoiseOrNot.lower() == "y":
        setVariation = float(setVariation)
        # NOTE(review): noise is added after the forward pass but before the
        # gradient is taken; whether autograd tolerates this depends on how
        # addGausianNoise modifies the parameters -- confirm.
        addGausianNoise(net, [setVariation])

    dy_dx = torch.autograd.grad(y, net.parameters())
    original_dy_dx = [g.detach().clone() for g in dy_dx]

    # generate dummy data and label
    dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
    dummy_label = torch.randn(gt_onehot_label.size()).to(device).requires_grad_(True)

    plt.imshow(tt(dummy_data[0].cpu()))

    optimizer = torch.optim.LBFGS([dummy_data, dummy_label])

    history = []
    # Soft dummy label from the most recent closure evaluation; used for
    # the final label comparison.
    dummy_onehot_label_copy = []
    lastDummyLossData = 0

    # Defined once (it was previously rebuilt on every iteration): returns
    # the squared L2 distance between dummy and shared gradients, after
    # back-propagating it into dummy_data / dummy_label.
    def closure():
        nonlocal dummy_onehot_label_copy
        optimizer.zero_grad()

        dummy_pred = net(dummy_data)
        dummy_onehot_label = F.softmax(dummy_label, dim=-1)
        dummy_loss = criterion(dummy_pred, dummy_onehot_label)
        # create_graph=True so grad_diff can be differentiated w.r.t. the dummies
        dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)

        grad_diff = 0
        for gx, gy in zip(dummy_dy_dx, original_dy_dx):
            grad_diff += ((gx - gy) ** 2).sum()
        grad_diff.backward()

        dummy_onehot_label_copy = dummy_onehot_label
        return grad_diff

    for iters in range(300):
        optimizer.step(closure)
        if iters % 10 == 0:
            # Extra evaluation purely for logging and the snapshot.
            current_loss = closure()
            lastDummyLossData = current_loss.item()
            print(iters, "%.4f" % lastDummyLossData)
            history.append(tt(dummy_data[0].cpu()))

    plt.figure(figsize=(12, 8))
    for i, snapshot in enumerate(history[:30]):
        plt.subplot(3, 10, i + 1)
        plt.imshow(snapshot)
        plt.title("iter=%d" % (i * 10))
        plt.axis('off')

    result = AreDummyAndOriginalLabelNoMatch(gt_onehot_label, dummy_onehot_label_copy)
    if showImageOnScreen.lower() == "y":
        plt.show()
    # Always release the figures: this function is meant to be called once
    # per sample, and figures left open on the "N" path accumulated across
    # calls (memory growth and matplotlib's too-many-figures warning).
    plt.close('all')
    if (lastDummyLossData > lossThresholdValueToJudgeInCorrect
            and judge_by_category_classification.lower() == "n"):
        print("Accuracy is being built by Dummy loss too")
        print("Dummy loss is still greater than the threshold."+ str(lossThresholdValueToJudgeInCorrect) +" Image restoration failed by the value " + str(lastDummyLossData))
        result = 0

    return int(result)