Example #1
 def test_run_resnet101(self):
     imgSize = (3, 224, 224)
     # should throw because privacy engine does not work with batch norm
     # remove the next two lines when we support batch norm
     with self.assertRaises(IncompatibleModuleException):
         self.runOneBatch(models.resnet101(), imgSize)
     self.runOneBatch(mm.convert_batchnorm_modules(models.resnet101()), imgSize)
Example #2
 def test_run_basic_case(self):
     imgSize = (3, 4, 5)
     # should throw because privacy engine does not work with batch norm
     # remove the next two lines when we support batch norm
     with self.assertRaises(IncompatibleModuleException):
         self.runOneBatch(BasicModel(imgSize), imgSize)
     self.runOneBatch(mm.convert_batchnorm_modules(BasicModel(imgSize)), imgSize)
Example #3
    def __init__(self, device=None, jit=False):
        super().__init__()
        self.device = device
        self.jit = jit

        self.model = models.resnet18(num_classes=10)
        self.model = convert_batchnorm_modules(self.model)
        self.model = self.model.to(device)

        self.example_inputs = (
            torch.randn((64, 3, 32, 32), device=self.device),
        )
        self.example_target = torch.randint(0, 10, (64,), device=self.device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()

        # This is supposed to equal the number of data points.
        # It is only used to compute stats, so the exact value does not matter.
        sample_size = 64 * 100
        clipping = {"clip_per_layer": False, "enable_stat": False}
        self.privacy_engine = PrivacyEngine(
            self.model,
            batch_size=64,
            sample_size=sample_size,
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=1.0,
            max_grad_norm=1.0,
            secure_rng=False,
            **clipping,
        )
        self.privacy_engine.attach(self.optimizer)
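
For context, one training step with the engine attached as above looks like a plain PyTorch step: clipping and noising happen inside optimizer.step(). A minimal sketch, assuming the pre-1.0 Opacus API used throughout these examples and the hypothetical 64x3x32x32 batch shape from the constructor:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from opacus import PrivacyEngine
from opacus.utils.module_modification import convert_batchnorm_modules

device = "cuda" if torch.cuda.is_available() else "cpu"
model = convert_batchnorm_modules(models.resnet18(num_classes=10)).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

privacy_engine = PrivacyEngine(
    model,
    batch_size=64,
    sample_size=64 * 100,
    alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)
privacy_engine.attach(optimizer)

# One step: per-sample gradients are clipped and noised by the attached engine.
images = torch.randn(64, 3, 32, 32, device=device)
targets = torch.randint(0, 10, (64,), device=device)
optimizer.zero_grad()
loss = criterion(model(images), targets)
loss.backward()
optimizer.step()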
Example #4
 def test_convert_batchnorm_modules_resnet50(self):
     model = models.resnet50()
     # check module BatchNorms is there
     self.checkModulePresent(model, nn.BatchNorm2d)
     # replace the module with instancenorm
     model = mm.convert_batchnorm_modules(model)
     # check module is not present
     self.checkModuleNotPresent(model, nn.BatchNorm2d)
     self.checkModulePresent(model, nn.GroupNorm)
Example #5
    def test_module_modification_convert_example(self):
        # IMPORTANT: When changing this code you also need to update
        # the docstring for opacus.utils.module_modification.convert_batchnorm_modules()
        from torchvision.models import resnet50

        model = resnet50()
        self.assertTrue(isinstance(model.layer1[0].bn1, nn.BatchNorm2d))

        model = convert_batchnorm_modules(model)
        self.assertTrue(isinstance(model.layer1[0].bn1, nn.GroupNorm))
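
The docstring test above works because the conversion swaps every BatchNorm layer for a GroupNorm over the same number of channels; batch normalization mixes statistics across samples, which is incompatible with per-sample gradient computation. Below is a rough, hand-rolled sketch of such a replacement; the choice of min(32, num_features) groups is an assumption about the library default, not quoted from it:

import torch.nn as nn

def batchnorm_to_groupnorm(bn: nn.modules.batchnorm._BatchNorm) -> nn.GroupNorm:
    # Swap a BatchNorm layer for a GroupNorm over the same number of channels.
    return nn.GroupNorm(min(32, bn.num_features), bn.num_features, affine=True)

def replace_batchnorm(module: nn.Module) -> nn.Module:
    # Recursively replace BatchNorm1d/2d/3d children in place.
    for name, child in module.named_children():
        if isinstance(child, nn.modules.batchnorm._BatchNorm):
            setattr(module, name, batchnorm_to_groupnorm(child))
        else:
            replace_batchnorm(child)
    return module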
Example #6
def get_model(private=True, cache=dirname + '/.cache/'):
    model = convert_batchnorm_modules(torchvision.models.resnet18(num_classes=10))

    if private:
        data = torch.load('./models/privatemodel_best.pth.tar')
    else:
        data = torch.load('./models/nonprivatemodel_best.pth.tar')

    model.load_state_dict(data['state_dict'])
    model.cuda()
    model.eval()

    return model
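
A hypothetical way to use get_model() above for inference; it assumes the checkpoint files it loads exist locally and that a GPU is available, since the helper calls model.cuda():

import torch

model = get_model(private=True)
images = torch.randn(8, 3, 32, 32, device="cuda")  # assumed CIFAR-10-sized inputs
with torch.no_grad():
    predictions = model(images).argmax(dim=1)
print(predictions)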
Example #7
def get_trained_model(dataset, dp):
    classes, trainloader, testloader, trainset, testset = get_test_train_loaders(
        dataset)
    device = torch.device("cuda")
    net = torchvision.models.alexnet(num_classes=len(classes)).to(device)

    if dp:
        net = module_modification.convert_batchnorm_modules(net)
    inspector = DPModelInspector()

    print(f"Model valid:        {inspector.validate(net)}")
    print(f"Model trained DP:   {dp}")
    net = net.to(device)

    if dp:
        PATH = './trained_models/' + dataset + '_dp' + '.pth'
    else:
        PATH = './trained_models/' + dataset + '.pth'
    # PATH='./trained_models/mnist_0_931.pth'
    # PATH='./trained_models/mnist_0_dp_316.pth'
    print(len(classes))
    net.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))
    return net, classes, trainloader, testloader, trainset, testset
Example #8
        elif args.dataset == 'dr':
            global_model = models.resnet50(num_classes=100)
            global_model.to(device)
            summary(global_model, input_size=(3, 32, 32), device=device)
    else:
        exit('Error: unrecognized model')
    ############# Common ###################

    ######### DP Model Compatibility #######
    if args.withDP:
        try:
            inspector = DPModelInspector()
            inspector.validate(global_model)
            print("Model's already Valid!\n")
        except:
            global_model = module_modification.convert_batchnorm_modules(
                global_model)
            inspector = DPModelInspector()
            print(f"Is the model valid? {inspector.validate(global_model)}")
            print("Model is convereted to be Valid!\n")
    ######### DP Model Compatibility #######

    ######### Local Models and Optimizers #############
    local_models = []
    local_optimizers = []
    local_privacy_engine = []

    for u in range(args.num_users):
        local_models.append(copy.deepcopy(global_model))

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(local_models[u].parameters(),
Example #9
 def test_convert_batchnorm(self):
     inspector = dp_inspector.DPModelInspector()
     model = convert_batchnorm_modules(models.resnet50())
     self.assertTrue(inspector.validate(model))
Example #10
def main():
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=0,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=90,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size-test",
        default=256,
        type=int,
        metavar="N",
        help="mini-batch size for test dataset (default: 256), this is the total "
        "batch size of all GPUs on the current node when "
        "using Data Parallel or Distributed Data Parallel",
    )
    parser.add_argument(
        "--sample-rate",
        default=0.04,
        type=float,
        metavar="SR",
        help="sample rate used for batch construction (default: 0.005)",
    )
    parser.add_argument(
        "-na",
        "--n_accumulation_steps",
        default=1,
        type=int,
        metavar="N",
        help="number of mini-batches to accumulate into an effective batch",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.1,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument(
        "--momentum", default=0.9, type=float, metavar="M", help="SGD momentum"
    )
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=5e-4,
        type=float,
        metavar="W",
        help="SGD weight decay",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument(
        "--seed", default=None, type=int, help="seed for initializing training. "
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="GPU ID for this process (default: 'cuda')",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=1.5,
        metavar="S",
        help="Noise multiplier (default 1.0)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=10.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 1.0)",
    )
    parser.add_argument(
        "--secure-rng",
        action="store_true",
        default=False,
        help="Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save check points",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument(
        "--log-dir", type=str, default="", help="Where Tensorboard log will be stored"
    )
    parser.add_argument(
        "--optim",
        type=str,
        default="SGD",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )
    parser.add_argument(
        "--lr-schedule", type=str, choices=["constant", "cos"], default="cos"
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank if multi-GPU training, -1 for single GPU training",
    )
    # New added args
    parser.add_argument(
        "--model-name",
        type=str,
        default="ConvNet",
        help="Name of the model structure",
    )
    parser.add_argument(
        "--results-folder",
        type=str,
        default="../results/cifar10",
        help="Where CIFAR10 results is/will be stored",
    )
    parser.add_argument(
        "--sub-training-size",
        type=int,
        default=0,
        help="Size of bagging",
    )
    parser.add_argument(
        "-r",
        "--n-runs",
        type=int,
        default=1,
        metavar="R",
        help="number of runs to average on (default: 1)",
    )
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="Save the trained model (default: false)",
    )
    parser.add_argument(
        "--run-test",
        action="store_true",
        default=False,
        help="Run test for the model (default: false)",
    )
    parser.add_argument(
        "--load-model",
        action="store_true",
        default=False,
        help="Load model not train (default: false)",
    )
    parser.add_argument(
        "--train-mode",
        type=str,
        default="DP",
        help="Train mode: DP, Sub-DP, Bagging",
    )
    parser.add_argument(
        "--sub-acc-test",
        action="store_true",
        default=False,
        help="Test subset V.S. acc (default: false)",
    )

    args = parser.parse_args()

    # folder path
    result_folder = result_folder_path_generator(args)
    print(f'Result folder: {result_folder}')
    models_folder = f"{result_folder}/models"
    Path(models_folder).mkdir(parents=True, exist_ok=True)

    # logging
    logging.basicConfig(filename=f"{result_folder}/train.log", filemode='w', level=logging.INFO)
    logging.getLogger().addHandler(logging.StreamHandler())

    # distributed = False
    # if args.local_rank != -1:
    #     setup()
    #     distributed = True

    # Sets `world_size = 1` if you run on a single GPU with `args.local_rank = -1`
    if args.device != "cpu":
        rank, local_rank, world_size = setup(args)
        device = local_rank
    else:
        device = "cpu"
        rank = 0
        world_size = 1


    if args.train_mode == 'Bagging' and args.n_accumulation_steps > 1:
        raise ValueError("Virtual steps only works with enabled DP")

    # The following lines enable stat gathering for the clipping process
    # and set a default of per layer clipping for the Privacy Engine
    clipping = {"clip_per_layer": False, "enable_stat": True}

    generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]

    train_transform = transforms.Compose(
        augmentations + normalize if args.train_mode == 'Bagging' else normalize
    )
    test_transform = transforms.Compose(normalize)

    def gen_sub_dataset(dataset, sub_training_size, with_replacement):
        indexs = np.random.choice(len(dataset), sub_training_size, replace=with_replacement)
        dataset = torch.utils.data.Subset(dataset, indexs)
        print(f"Sub-dataset size {len(dataset)}")
        return dataset

    def gen_train_dataset_loader(sub_training_size):
        train_dataset = CIFAR10(
            root=args.data_root, train=True, download=True, transform=train_transform
        )

        if args.train_mode == 'Sub-DP' or args.train_mode == 'Bagging':
            train_dataset = gen_sub_dataset(train_dataset, sub_training_size, True)
        
        batch_num = None

        if world_size > 1:
            dist_sampler = DistributedSampler(train_dataset)
        else:
            dist_sampler = None

        # batch_size = int(args.sample_rate * len(train_dataset))

        # if args.train_mode == 'DP' or args.train_mode == 'Sub-DP':
        #     train_loader = torch.utils.data.DataLoader(
        #         train_dataset,
        #         num_workers=args.workers,
        #         generator=generator,
        #         batch_size=batch_size,
        #         sampler=dist_sampler,
        #     )
        # elif args.train_mode == 'Sub-DP-no-amp':
        #     train_loader = torch.utils.data.DataLoader(
        #         train_dataset,
        #         num_workers=args.workers,
        #         generator=generator,
        #         batch_size=batch_size,
        #         sampler=dist_sampler,
        #     )
        #     batch_num = int(sub_training_size / int(args.sample_rate * len(train_dataset)))
        # else:
        #     print('No Gaussian Sampler')
        #     train_loader = torch.utils.data.DataLoader(
        #         train_dataset,
        #         num_workers=args.workers,
        #         generator=generator,
        #         batch_size=int(128/world_size),
        #         sampler=dist_sampler,
        #     )
        # return train_dataset, train_loader, batch_num
        
        if args.train_mode == 'DP' or args.train_mode == 'Sub-DP':
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                num_workers=args.workers,
                generator=generator,
                batch_sampler=FixedSizedUniformWithReplacementSampler(
                    num_samples=len(train_dataset),
                    sample_rate=args.sample_rate/world_size,
                    train_size=len(train_dataset)/world_size,
                    generator=generator,
                ),
            )
        elif args.train_mode == 'Sub-DP-no-amp':
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                num_workers=args.workers,
                generator=generator,
                batch_sampler=FixedSizedUniformWithReplacementSampler(
                    num_samples=len(train_dataset),
                    sample_rate=args.sample_rate/world_size,
                    train_size=sub_training_size/world_size,
                    generator=generator,
                ),
            )
        else:
            print('No Gaussian Sampler')
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                num_workers=args.workers,
                generator=generator,
                batch_size=int(256/world_size),
                sampler=dist_sampler,
            )
        return train_dataset, train_loader, batch_num

    def gen_test_dataset_loader():
        test_dataset = CIFAR10(
            root=args.data_root, train=False, download=True, transform=test_transform
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=args.batch_size_test,
            shuffle=False,
            num_workers=args.workers,
        )
        return test_dataset, test_loader
    

    # if distributed and args.device == "cuda":
    #     args.device = "cuda:" + str(args.local_rank)
    # device = torch.device(args.device)

    """ Here we go the training and testing process """
    
    # collect votes from all models
    test_dataset, test_loader = gen_test_dataset_loader()
    aggregate_result = np.zeros([len(test_dataset), 10 + 1], dtype=int)
    aggregate_result_softmax = np.zeros([args.n_runs, len(test_dataset), 10 + 1], dtype=np.float32)
    acc_list = []

    # use this code for "sub_training_size V.S. acc"
    if args.sub_acc_test:
        sub_acc_list = []
    
    for run_idx in range(args.n_runs):
        # Pre-training stuff for each base classifier
        
        # Define the model
        if args.model_name == 'ConvNet':
            model = convnet(num_classes=10)
        elif args.model_name == 'ResNet18-BN':
            model = ResNet18(10)
            # model = module_modification.convert_batchnorm_modules(models.resnet18(pretrained=False, num_classes=10))
            # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            # model = models.resnet18(pretrained=False, num_classes=10)
            # model = module_modification.convert_batchnorm_modules(ResNet18(10))
        elif args.model_name == 'ResNet18-GN':
            model = module_modification.convert_batchnorm_modules(ResNet18(10))
        elif args.model_name == 'LeNet':
            model = LeNet()
        else:
            exit(f'Model name {args.model_name} invalid.')
        model = model.to(device)
        
        if world_size > 1:
            if not args.train_mode == 'Bagging':
                model = DPDDP(model)
            else:
                # model = DDP(model, device_ids=[args.local_rank])
                model = DDP(model, device_ids=[device])
        
        # Define the optimizer
        if args.optim == "SGD":
            optimizer = optim.SGD(
                model.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
            )
        elif args.optim == "RMSprop":
            optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
        elif args.optim == "Adam":
            optimizer = optim.Adam(model.parameters(), lr=args.lr)
        else:
            raise NotImplementedError("Optimizer not recognized. Please check spelling")

        # Define the DP engine
        if args.train_mode != 'Bagging':
            privacy_engine = PrivacyEngine(
                model,
                sample_rate=args.sample_rate * args.n_accumulation_steps / world_size,
                alphas=[1 + x / 10.0 for x in range(1, 100)],
                noise_multiplier= 0.0 if args.train_mode == 'Bagging' else args.sigma,
                max_grad_norm=args.max_per_sample_grad_norm,
                secure_rng=args.secure_rng,
                **clipping,
            )
            privacy_engine.attach(optimizer)
        
        # Training and testing
        model_pt_file = f"{models_folder}/model_{run_idx}.pt"
        def training_process():
            logging.info(f'training model_{run_idx}...')
            # use this code for "sub_training_size V.S. acc"
            if args.sub_acc_test:
                sub_training_size = int(50000 - 50000 / args.n_runs * run_idx)
                _, train_loader, batch_num = gen_train_dataset_loader(sub_training_size)
            else:    
                sub_training_size = args.sub_training_size
                _, train_loader, batch_num = gen_train_dataset_loader(sub_training_size)

            epoch_acc_epsilon = []
            for epoch in range(args.start_epoch, args.epochs + 1):
                if args.lr_schedule == "cos":
                    lr = args.lr * 0.5 * (1 + np.cos(np.pi * epoch / (args.epochs + 1)))
                    for param_group in optimizer.param_groups:
                        param_group["lr"] = lr

                train(args, model, train_loader, optimizer, epoch, device, batch_num)

                # if args.run_test:
                #     logging.info(f'Epoch: {epoch}')
                #     test(args, model, test_loader, device)

                # if run_idx == 0:
                #     logging.info(f'Epoch: {epoch}')
                #     acc = test(args, model, test_loader, device)
                #     if args.train_mode in ['DP', 'Sub-DP', 'Sub-DP-no-amp']:
                #         eps, _ = optimizer.privacy_engine.get_privacy_spent(args.delta)
                #         epoch_acc_epsilon.append((acc, eps))
            
            
            if run_idx == 0:
                np.save(f"{result_folder}/epoch_acc_eps", epoch_acc_epsilon)
            
            # Post-training stuff 

            # use this code for "sub_training_size V.S. acc"
            if args.sub_acc_test:
                sub_acc_list.append((sub_training_size, test(args, model, test_loader, device)))

            # save the DP related data
            if run_idx == 0 and args.train_mode in ['DP', 'Sub-DP', 'Sub-DP-no-amp']:
                rdp_alphas, rdp_epsilons = optimizer.privacy_engine.get_rdp_privacy_spent()
                dp_epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(args.delta)
                rdp_steps = optimizer.privacy_engine.steps
                logging.info(f"epsilon {dp_epsilon}, best_alpha {best_alpha}, steps {rdp_steps}")

                logging.info(f"sample_rate {optimizer.privacy_engine.sample_rate}, noise_multiplier {optimizer.privacy_engine.noise_multiplier}, steps {optimizer.privacy_engine.steps}")
                
                np.save(f"{result_folder}/rdp_epsilons", rdp_epsilons)
                np.save(f"{result_folder}/rdp_alphas", rdp_alphas)
                np.save(f"{result_folder}/rdp_steps", rdp_steps)
                np.save(f"{result_folder}/dp_epsilon", dp_epsilon)
        
        if os.path.isfile(model_pt_file) or args.load_model:
            try:
                torch.distributed.barrier()
                logging.info(f'loading existing model_{run_idx}...')
                map_location = {'cuda:%d' % 0: 'cuda:%d' % local_rank}
                model.load_state_dict(torch.load(model_pt_file, map_location=map_location))
            except Exception as inst:
                logging.info(f'fail to load model with error: {inst}')
                training_process()
        else:
            training_process()
        
        # save preds and model
        aggregate_result[np.arange(0, len(test_dataset)), pred(args, model, test_loader, device)] += 1
        aggregate_result_softmax[run_idx, np.arange(0, len(test_dataset)), 0:10] = softmax(args, model, test_loader, device)
        acc_list.append(test(args, model, test_loader, device))
        if not args.load_model and args.save_model and local_rank == 0:
            torch.save(model.state_dict(), model_pt_file)

    # Finished training all models, save results
    aggregate_result[np.arange(0, len(test_dataset)), -1] = next(iter(torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset))))[1]
    aggregate_result_softmax[:, np.arange(0, len(test_dataset)), -1] = next(iter(torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset))))[1]
    np.save(f"{result_folder}/aggregate_result", aggregate_result)
    np.save(f"{result_folder}/aggregate_result_softmax", aggregate_result_softmax)
    np.save(f"{result_folder}/acc_list", acc_list)

    # use this code for "sub_training_size V.S. acc"
    if args.sub_acc_test:
        np.save(f"{result_folder}/subset_acc_list", sub_acc_list)

    if world_size > 1:
        cleanup()
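
The script above saves aggregate_result with shape [len(test_dataset), 11]: columns 0-9 hold per-class vote counts across the n_runs models and the last column holds the true label. A small sketch of how the saved file could be turned into a majority-vote accuracy (the folder path is a placeholder for whatever result_folder_path_generator produced):

import numpy as np

result_folder = "path/to/result_folder"  # placeholder
votes = np.load(f"{result_folder}/aggregate_result.npy")
labels = votes[:, -1]
majority_pred = votes[:, :-1].argmax(axis=1)
print("majority-vote accuracy:", (majority_pred == labels).mean())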
Example #11
def main():
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "-sr",
        "--sample-rate",
        type=float,
        default=0.001,
        metavar="SR",
        help="sample rate used for batch construction (default: 0.001)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1024,
        metavar="TB",
        help="input batch size for testing (default: 1024)",
    )
    parser.add_argument(
        "-n",
        "--epochs",
        type=int,
        default=10,
        metavar="N",
        help="number of epochs to train (default: 14)",
    )
    parser.add_argument(
        "-r",
        "--n-runs",
        type=int,
        default=1,
        metavar="R",
        help="number of runs to average on (default: 1)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.1,
        metavar="LR",
        help="learning rate (default: .1)",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=1.0,
        metavar="S",
        help="Noise multiplier (default 1.0)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=1.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 1.0)",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="GPU ID for this process (default: 'cuda')",
    )
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="Save the trained model (default: false)",
    )
    parser.add_argument(
        "--run-test",
        action="store_true",
        default=False,
        help="Run test for the model (default: false)",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../mnist",
        help="Where MNIST is/will be stored",
    )
    parser.add_argument(
        "--results-folder",
        type=str,
        default="../results/fashion_mnist",
        help="Where MNIST results is/will be stored",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="SampleConvNet",
        help="Name of the model",
    )
    parser.add_argument(
        "--sub-training-size",
        type=int,
        default=0,
        help="Size of bagging",
    )
    parser.add_argument(
        "--load-model",
        action="store_true",
        default=False,
        help="Load model not train (default: false)",
    )
    parser.add_argument(
        "--sub-acc-test",
        action="store_true",
        default=False,
        help="Test subset V.S. acc (default: false)",
    )
    parser.add_argument(
        "--train-mode",
        type=str,
        default="DP",
        help="Mode of training: DP, Bagging, Sub-DP",
    )
    args = parser.parse_args()
    device = torch.device(args.device)
    kwargs = {"num_workers": 1, "pin_memory": True}

    def gen_sub_dataset(dataset, sub_training_size, with_replacement):
        indexs = np.random.choice(len(dataset), sub_training_size, replace=with_replacement)
        dataset = torch.utils.data.Subset(dataset, indexs)
        print(f"Sub-dataset size {len(dataset)}")
        return dataset

    def gen_train_dataset_loader(or_sub_training_size=None):
        train_dataset = datasets.FashionMNIST(
            args.data_root,
            train=True,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
                ]
            ),
        )

        # Generate sub-training dataset if necessary
        sub_training_size = args.sub_training_size if or_sub_training_size is None else or_sub_training_size
        if args.train_mode == 'Bagging' or args.train_mode == 'Sub-DP':
            train_dataset = gen_sub_dataset(train_dataset, sub_training_size, True)

        if args.train_mode == 'DP' or args.train_mode == 'Sub-DP':
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                generator=None,
                batch_sampler=UniformWithReplacementSampler(
                    num_samples=len(train_dataset),
                    sample_rate=args.sample_rate,
                    generator=None,
                ),
                **kwargs,
            )
        elif args.train_mode == 'Sub-DP-no-amp':
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                generator=None,
                batch_sampler=FixedSizedUniformWithReplacementSampler(
                    num_samples=len(train_dataset),
                    sample_rate=args.sample_rate,
                    train_size=sub_training_size,
                    generator=None,
                ),
                **kwargs,
            )
        else:
            print('No Gaussian Sampler')
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=64,
                shuffle=True,
                **kwargs,
            )
        return train_dataset, train_loader

    def gen_test_dataset_loader():
        test_dataset = datasets.FashionMNIST(
            args.data_root,
            train=False,
            download=True,
            transform=transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize((MNIST_MEAN,), (MNIST_STD,)),
                ]
            ),
        )

        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs,
        )

        return test_dataset, test_loader

    # folder for this experiment
    result_folder = result_folder_path_generator(args)
    print(f'Result folder: {result_folder}')
    models_folder = f"{result_folder}/models"
    Path(models_folder).mkdir(parents=True, exist_ok=True)

    # log file for this experiment
    logging.basicConfig(
        filename=f"{result_folder}/train.log", filemode='w', level=logging.INFO)
    logging.getLogger().addHandler(logging.StreamHandler())

    # collect votes from all models
    test_dataset, test_loader = gen_test_dataset_loader()
    aggregate_result = np.zeros([len(test_dataset), 10 + 1], dtype=int)
    aggregate_result_softmax = np.zeros([args.n_runs, len(test_dataset), 10 + 1], dtype=np.float32)
    acc_list = []
    
    # use this code for "sub_training_size V.S. acc"
    if args.sub_acc_test:
        sub_acc_list = []

    for run_idx in range(args.n_runs):
        # pre-training stuff
        if args.model_name == 'SampleConvNet':
            model = SampleConvNet().to(device)
        elif args.model_name == 'ResNet18':
            model = module_modification.convert_batchnorm_modules(models.resnet18(pretrained=False, num_classes=10))
            # change input layer
            # the default number of input channels in resnet is 3, but our images have 1 channel, so we change 3 to 1.
            # nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) <- default
            model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            model.to(device)
        elif args.model_name == 'LeNet':
            model = LeNet().to(device)
        else:
            logging.warn(f"Model name {args.model_name} invaild.")
            exit()
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0)
        
        if args.train_mode != 'Bagging':
            privacy_engine = PrivacyEngine(
                model,
                sample_rate=args.sample_rate,
                alphas=[
                    1 + x / 10.0 for x in range(1, 100)] + list(range(12, 1500)),
                noise_multiplier=args.sigma,
                max_grad_norm=args.max_per_sample_grad_norm,
            )
            privacy_engine.attach(optimizer)

        # training
        if args.load_model:
            model.load_state_dict(torch.load(
                f"{models_folder}/model_{run_idx}.pt"))
        else:
            # use this code for "sub_training_size V.S. acc"
            if args.sub_acc_test:
                sub_training_size = int(60000 - 60000 / args.n_runs * run_idx)
                _, train_loader = gen_train_dataset_loader(sub_training_size)
            else:
                _, train_loader = gen_train_dataset_loader()

            epoch_acc_epsilon = []
            for epoch in range(1, args.epochs + 1):
                train(args, model, device, train_loader, optimizer, epoch)
                if args.run_test:
                    logging.info(f'Epoch: {epoch}')
                    test(args, model, device, test_loader)
                if run_idx == 0:
                    logging.info(f'Epoch: {epoch}')
                    acc = test(args, model, device, test_loader)
                    if args.train_mode in ['DP', 'Sub-DP', 'Sub-DP-no-amp']:
                        eps, _ = optimizer.privacy_engine.get_privacy_spent(args.delta)
                        epoch_acc_epsilon.append((acc, eps))    
            if run_idx == 0:
                np.save(f"{result_folder}/epoch_acc_eps", epoch_acc_epsilon)
        
        # use this code for "sub_training_size V.S. acc"
        if args.sub_acc_test:
            sub_acc_list.append((sub_training_size, test(args, model, device, test_loader)))

        # post-training stuff
        if run_idx == 0 and args.train_mode in ['DP', 'Sub-DP', 'Sub-DP-no-amp']:
            rdp_alphas, rdp_epsilons = optimizer.privacy_engine.get_rdp_privacy_spent()
            dp_epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
                args.delta)
            rdp_steps = optimizer.privacy_engine.steps
            logging.info(
                f"epsilon {dp_epsilon}, best_alpha {best_alpha}, steps {rdp_steps}")
            print(
                f"epsilon {dp_epsilon}, best_alpha {best_alpha}, steps {rdp_steps}")

            np.save(f"{result_folder}/rdp_epsilons", rdp_epsilons)
            np.save(f"{result_folder}/rdp_alphas", rdp_alphas)
            np.save(f"{result_folder}/rdp_steps", rdp_steps)
            np.save(f"{result_folder}/dp_epsilon", dp_epsilon)
        # save preds
        aggregate_result[np.arange(0, len(test_dataset)), pred(
            args, model, test_dataset, device).cpu()] += 1
        aggregate_result_softmax[run_idx, np.arange(0, len(test_dataset)), 0:10] = softmax(args, model, test_dataset, device).cpu().detach().numpy()
        acc_list.append(test(args, model, device, test_loader))
        # save model
        if not args.load_model and args.save_model:
            torch.save(model.state_dict(),
                       f"{models_folder}/model_{run_idx}.pt")
    # finished training all models, save results
    aggregate_result[np.arange(0, len(test_dataset)), -1] = next(
        iter(torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset))))[1]
    aggregate_result_softmax[:, np.arange(0, len(test_dataset)), -1] = next(iter(torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset))))[1]
    np.save(f"{result_folder}/aggregate_result_softmax", aggregate_result_softmax)
    np.save(f"{result_folder}/aggregate_result", aggregate_result)
    np.save(f"{result_folder}/acc_list", acc_list)

    # use this code for "sub_training_size V.S. acc"
    if args.sub_acc_test:
        np.save(f"{result_folder}/subset_acc_list", sub_acc_list)
Example #12
import torch
import torch.nn as nn
import torchvision.models as models
from opacus.utils.module_modification import convert_batchnorm_modules
import time

from functorch import vmap, grad
from functorch import make_functional
from opacus import PrivacyEngine

device = 'cuda'
batch_size = 128
torch.manual_seed(0)

model_functorch = convert_batchnorm_modules(models.resnet18(num_classes=10))
model_functorch = model_functorch.to(device)
criterion = nn.CrossEntropyLoss()

images = torch.randn(batch_size, 3, 32, 32, device=device)
targets = torch.randint(0, 10, (batch_size, ), device=device)
func_model, weights = make_functional(model_functorch)


def compute_loss(weights, image, target):
    images = image.unsqueeze(0)
    targets = target.unsqueeze(0)
    output = func_model(weights, images)
    loss = criterion(output, targets)
    return loss
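
compute_loss is written to take a single example so that it can be vectorized. A sketch, continuing the snippet above, of how it is typically combined with functorch's grad and vmap to obtain per-sample gradients (grad differentiates with respect to the first argument, weights; vmap maps over the batch dimension of images and targets):

# Per-sample gradients: one gradient per example, stacked along a new leading dim.
per_sample_grad_fn = vmap(grad(compute_loss), in_dims=(None, 0, 0))
per_sample_grads = per_sample_grad_fn(weights, images, targets)
# per_sample_grads is a tuple matching `weights`, each entry with a leading
# batch dimension of size batch_size.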

Example #13
# number of batches after which it is time to do validation
validation_batches = ceil(len(dataloader) / args.validations_per_epoch)

print(f"Using sigma={args.sigma} and C={args.max_grad_norm}")

if args.delta > 0:

    privacy_engine = PrivacyEngine(
        model,
        batch_size=batch_size,
        sample_size=len(train_dataset),
        alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
        noise_multiplier=args.sigma,
        max_grad_norm=args.max_grad_norm,
    )
    model = module_modification.convert_batchnorm_modules(model).cuda()
    inspector = DPModelInspector()
    print(inspector.validate(model), "--------------")

    privacy_engine.attach(optimizer)

start_epoch = 0
# a list of validation IWAE estimates
validation_iwae = []
# a list of running variational lower bounds on the train set
train_vlb = []
# the length of two lists above is the same because the new
# values are inserted into them at the validation checkpoints only

# load the last checkpoint, if it exists
if exists(join(args.model_dir, 'last_checkpoint_{}.tar'.format(args.exp))):
Example #14
def main():
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=2,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=90,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        # This should be 256, but that OOMs using the prototype.
        default=64,
        type=int,
        metavar="N",
        help="mini-batch size (default: 64), this is the total "
        "batch size of all GPUs on the current node when "
        "using Data Parallel or Distributed Data Parallel",
    )
    parser.add_argument(
        "-na",
        "--n_accumulation_steps",
        default=1,
        type=int,
        metavar="N",
        help="number of mini-batches to accumulate into an effective batch",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.001,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument("--momentum",
                        default=0.9,
                        type=float,
                        metavar="M",
                        help="SGD momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=5e-4,
        type=float,
        metavar="W",
        help="SGD weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument("--seed",
                        default=None,
                        type=int,
                        help="seed for initializing training. ")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="GPU ID for this process (default: 'cuda')",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=1.0,
        metavar="S",
        help="Noise multiplier (default 1.0)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=1.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 1.0)",
    )
    parser.add_argument(
        "--disable-dp",
        action="store_true",
        default=False,
        help="Disable privacy training and just train with vanilla SGD",
    )
    parser.add_argument(
        "--secure-rng",
        action="store_true",
        default=False,
        help=
        "Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save check points",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument("--log-dir",
                        type=str,
                        default="",
                        help="Where Tensorboard log will be stored")
    parser.add_argument(
        "--optim",
        type=str,
        default="Adam",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )

    args = parser.parse_args()
    args.disable_dp = True

    if args.disable_dp and args.n_accumulation_steps > 1:
        raise ValueError("Virtual steps only works with enabled DP")

    # The following few lines, enable stats gathering about the run
    # 1. where the stats should be logged
    stats.set_global_summary_writer(
        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir)))
    # 2. enable stats
    stats.add(
        # stats about gradient norms aggregated for all layers
        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
        # stats about gradient norms per layer
        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
        # stats about clipping
        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
        # stats on training accuracy
        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
        # stats on validation accuracy
        stats.Stat(stats.StatType.TEST, "accuracy"),
    )

    # The following lines enable stat gathering for the clipping process
    # and set a default of per layer clipping for the Privacy Engine
    clipping = {"clip_per_layer": False, "enable_stat": True}

    if args.secure_rng:
        assert False
        try:
            import torchcsprng as prng
        except ImportError as e:
            msg = (
                "To use secure RNG, you must install the torchcsprng package! "
                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
            )
            raise ImportError(msg) from e

        generator = prng.create_random_device_generator("/dev/urandom")

    else:
        generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ]
    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize)

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(root=args.data_root,
                            train=True,
                            download=True,
                            transform=train_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        drop_last=True,
        generator=generator,
    )

    test_dataset = CIFAR10(root=args.data_root,
                           train=False,
                           download=True,
                           transform=test_transform)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0
    device = torch.device(args.device)
    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
    # model = CIFAR10Model()
    model = model.to(device)

    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(
            "Optimizer not recognized. Please check spelling")

    if not args.disable_dp:
        privacy_engine = PrivacyEngine(
            model,
            batch_size=args.batch_size * args.n_accumulation_steps,
            sample_size=len(train_dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.sigma,
            max_grad_norm=args.max_per_sample_grad_norm,
            secure_rng=args.secure_rng,
            **clipping,
        )
        privacy_engine.attach(optimizer)

    for epoch in range(args.start_epoch, args.epochs + 1):
        train(args, model, train_loader, optimizer, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": "ResNet18",
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )
Example #15
def main():
    parser = argparse.ArgumentParser(description="Trainer")
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=64,
        metavar="B",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "-na",
        "--n_accumulation_steps",
        default=1,
        type=int,
        metavar="N",
        help="number of mini-batches to accumulate into an effective batch",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=200,
        metavar="TB",
        help="input batch size for testing (default: 1024)",
    )
    parser.add_argument(
        "-d",
        "--data",
        type=str,
        default="mnist",
        metavar="D",
        help="dataset to train on (mnist, fashion, cifar10, cifar100"
    )
    parser.add_argument(
        "-n",
        "--epochs",
        type=int,
        default=10,
        metavar="N",
        help="number of epochs to train (default: 14)",
    )
    parser.add_argument(
        "-r",
        "--n-runs",
        type=int,
        default=1,
        metavar="R",
        help="number of runs to average on (default: 1)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.1,
        metavar="LR",
        help="learning rate (default: .1)",
    )
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=0.,
        type=float,
        metavar="W",
        help="SGD weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "--momentum", default=0., type=float, metavar="M", help="SGD momentum"
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=1.0,
        metavar="S",
        help="Noise multiplier (default 1.0)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=1.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 1.0)",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="GPU ID for this process (default: 'cuda')",
    )
    parser.add_argument(
        "--save-model",
        action="store_true",
        default=False,
        help="Save the trained model (default: false)",
    )
    parser.add_argument(
        "--disable-dp",
        action="store_true",
        default=False,
        help="Disable privacy training and just train with vanilla SGD",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="./data",
        help="Where MNIST is/will be stored",
    )

    args = parser.parse_args()
    device = torch.device(args.device)

    kwargs = {"num_workers": 1, "pin_memory": True}

    train_set, test_set, train_loader, test_loader = SangExp.get_data(
        data=args.data,
        batch_size=args.batch_size,
        test_batch_size=args.test_batch_size,
        download_path=args.data_root
    )

    
    if args.data == "mnist" or args.data == "fashion":
        model = ConvNet()
    elif args.data == "cifar10":
        model = resnet18(num_classes=10)
    elif args.data == "cifar100":
        model = resnet18(num_classes=100)

    model = convert_batchnorm_modules(model).to(device)


    optimizer = optim.SGD(
        model.parameters(),
        lr = args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    if not args.disable_dp:
        privacy_engine = PrivacyEngine(
            model,
            batch_size=args.batch_size * args.n_accumulation_steps,
            sample_size=len(train_set),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.sigma,
            max_grad_norm=args.max_per_sample_grad_norm,
            clip_per_layer=False,
            enable_stat=False
        )
        privacy_engine.attach(optimizer)

    stats = []
    for epoch in range(args.epochs):
        stat = []
        if not args.disable_dp:
            epsilon, best_alpha = train(args, model, device, train_loader, optimizer, epoch)
            stat.append(epsilon)
            stat.append(best_alpha)
        else:
            train(args, model, device, train_loader, optimizer, epoch)
        acc = test(args, model, device, test_loader)
        stat.append(acc)
        stats.append(tuple(stat))

    name = "save/{}".format(args.data)
    if not args.disable_dp:
        name += "_dp"
    np.save(name, stats)
    torch.save(model.state_dict(), name+".pt")
Example #16
File: dcgan.py Project: zhuyawen/opacus
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input,
                                               range(self.ngpu))
        else:
            output = self.main(input)
        return output


netG = Generator(ngpu)
if not opt.disable_dp:
    netG = convert_batchnorm_modules(netG)
netG = netG.to(device)
netG.apply(weights_init)
if opt.netG != "":
    netG.load_state_dict(torch.load(opt.netG))


class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
Example #17
def main():
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=2,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=100,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        default=256,
        type=int,
        metavar="N",
        help="mini-batch size (default: 256), this is the total "
        "batch size of all GPUs on the current node when "
        "using Data Parallel or Distributed Data Parallel",
    )
    parser.add_argument(
        "-na",
        "--n_accumulation_steps",
        default=1,
        type=int,
        metavar="N",
        help="number of mini-batches to accumulate into an effective batch",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.001,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument("--momentum",
                        default=0.9,
                        type=float,
                        metavar="M",
                        help="SGD momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=5e-4,
        type=float,
        metavar="W",
        help="SGD weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=5,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument("--seed",
                        default=None,
                        type=int,
                        help="seed for initializing training. ")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="GPU ID for this process (default: 'cuda')",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=0.001,
        metavar="S",
        help="Noise multiplier (default 1.0)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=100.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 1.0)",
    )
    parser.add_argument(
        "--disable-dp",
        action="store_true",
        default=False,
        help="Disable privacy training and just train with vanilla SGD",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save check points",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument("--log-dir",
                        type=str,
                        default="",
                        help="Where Tensorboard log will be stored")
    parser.add_argument(
        "--optim",
        type=str,
        default="Adam",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )
    parser.add_argument('--save_path',
                        type=str,
                        default='/content/drive/My Drive/resnet18')

    args = parser.parse_args()

    # The following few lines, enable stats gathering about the run
    # 1. where the stats should be logged
    stats.set_global_summary_writer(
        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir)))
    # 2. enable stats
    stats.add(
        # stats about gradient norms aggregated for all layers
        stats.Stat(stats.StatType.CLIPPING, "AllLayers", frequency=0.1),
        # stats about gradient norms per layer
        stats.Stat(stats.StatType.CLIPPING, "PerLayer", frequency=0.1),
        # stats about clipping
        stats.Stat(stats.StatType.CLIPPING, "ClippingStats", frequency=0.1),
        # stats on training accuracy
        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
        # stats on validation accuracy
        stats.Stat(stats.StatType.TEST, "accuracy"),
    )

    # The following lines enable stat gathering for the clipping process
    # and set a default of per layer clipping for the Privacy Engine
    clipping = {"clip_per_layer": False, "enable_stat": True}
    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
        # transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
    ]

    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize)

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(root=args.data_root,
                            train=True,
                            download=True,
                            transform=train_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        drop_last=True,
    )

    test_dataset = CIFAR10(root=args.data_root,
                           train=False,
                           download=True,
                           transform=test_transform)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0
    device = torch.device(args.device)
    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
    model = model.to(device)

    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(
            "Optimizer not recognized. Please check spelling")

    if not args.disable_dp:
        privacy_engine = PrivacyEngine(
            model,
            batch_size=args.batch_size * args.n_accumulation_steps,
            sample_size=len(train_dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.sigma,
            max_grad_norm=args.max_per_sample_grad_norm,
            **clipping,
        )
        privacy_engine.attach(optimizer)

    for epoch in range(args.start_epoch, args.epochs + 1):
        if not args.disable_dp:
            epsilon, best_alpha = train(args, model, train_loader, optimizer,
                                        epoch, device)
        else:
            train(args, model, train_loader, optimizer, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        if not args.disable_dp and epoch % 5 == 0:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'epoch': epoch,
                    'epsilon': epsilon,
                    'best_alpha': best_alpha,
                    'accuracy': top1_acc
                },
                os.path.join(args.save_path,
                             f"resnet18_cifar10_dp_{epoch}.tar"))

        # else:
        #     save_checkpoint(
        #         {
        #             "epoch": epoch,
        #             "arch": "ResNet18",
        #             "state_dict": model.state_dict(),
        #             "best_acc1": best_acc1,
        #             "optimizer": optimizer.state_dict(),
        #         },
        #         is_best,
        #         filename=args.checkpoint_file + ".tar",
        #     )

    if args.disable_dp:
        torch.save(model.state_dict(),
                   os.path.join(args.save_path, 'resnet18_cifar10.pt'))