def test_stats_example(self):
        # Walks through the stats API against an in-memory writer: first a
        # Stat that is logged to directly, then a Stat registered globally
        # and fed through stats.update().
        #
        # IMPORTANT: When changing this code you also need to update
        # the docstrings for opacus.utils.stats.Stat
        class MockSummaryWriter:
            # Minimal stand-in for a summary writer: keeps every scalar in a
            # nested dict keyed by series name, then by iteration.
            def __init__(self):
                self.logs = defaultdict(dict)

            def add_scalar(self, name, value, iter):
                # A second write to the same (name, iter) pair overwrites the first.
                self.logs[name][iter] = value

        mock_summary_writer = MockSummaryWriter()
        stats.set_global_summary_writer(mock_summary_writer)

        # Log 21 values through a Stat created with frequency=0.1.
        stat = stats.Stat(stats.StatType.GRAD, "sample_stats", frequency=0.1)
        for i in range(21):
            stat.log({"val": i})

        # Only 2 of the 21 log() calls are expected to reach the writer,
        # under the series name "<TYPE>:<stat name>/<key>".
        self.assertEqual(len(mock_summary_writer.logs["GRAD:sample_stats/val"]), 2)

        # Register a TEST stat globally and push one value through the
        # module-level update; no assertion follows, so this part only
        # checks that the calls do not raise.
        stats.add(stats.Stat(stats.StatType.TEST, "accuracy", frequency=1.0))
        stats.update(stats.StatType.TEST, acc1=1.0)
Пример #2
0
def main():
    """Train and evaluate a ResNet18 on CIFAR10 from the command line.

    Steps, in order: parse the CLI arguments, register the tensorboard
    stats (written under ``/tmp/stat/<log-dir>``), build the CIFAR10 train
    and test loaders, build the model and optimizer, optionally attach a
    ``PrivacyEngine``, then run ``train``/``test`` for every epoch and call
    ``save_checkpoint`` after each one.

    NOTE(review): ``args.disable_dp`` is overwritten with ``True`` right
    after parsing, so the ``--disable-dp`` flag has no effect and the
    ``PrivacyEngine`` branch below never runs. This looks like a leftover
    experiment switch -- confirm before relying on DP training here.

    Raises:
        ValueError: if ``--n_accumulation_steps`` is greater than 1 while DP
            is disabled (which, given the override above, is always).
        AssertionError: if ``--secure-rng`` is passed; it is switched off
            in this script.
        NotImplementedError: if ``--optim`` is not SGD, RMSprop or Adam.
    """
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=2,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=90,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        # This should be 256, but that OOMs using the prototype.
        default=64,
        type=int,
        metavar="N",
        help="mini-batch size (default: 64), this is the total "
        "batch size of all GPUs on the current node when "
        "using Data Parallel or Distributed Data Parallel",
    )
    parser.add_argument(
        "-na",
        "--n_accumulation_steps",
        default=1,
        type=int,
        metavar="N",
        help="number of mini-batches to accumulate into an effective batch",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.001,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument("--momentum",
                        default=0.9,
                        type=float,
                        metavar="M",
                        help="SGD momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=5e-4,
        type=float,
        metavar="W",
        # Help text now states the real default (it used to claim 1e-4).
        help="SGD weight decay (default: 5e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency (default: 10)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument("--seed",
                        default=None,
                        type=int,
                        help="seed for initializing training. ")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="GPU ID for this process (default: 'cuda')",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=1.0,
        metavar="S",
        help="Noise multiplier (default 1.0)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=1.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 1.0)",
    )
    parser.add_argument(
        "--disable-dp",
        action="store_true",
        default=False,
        help="Disable privacy training and just train with vanilla SGD",
    )
    parser.add_argument(
        "--secure-rng",
        action="store_true",
        default=False,
        help=
        "Enable Secure RNG to have trustworthy privacy guarantees. Comes at a performance cost",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save check points",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument("--log-dir",
                        type=str,
                        default="",
                        help="Where Tensorboard log will be stored")
    parser.add_argument(
        "--optim",
        type=str,
        default="Adam",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )

    args = parser.parse_args()
    # NOTE(review): unconditional override -- DP is always off in this
    # script regardless of what was passed on the command line.
    args.disable_dp = True

    if args.disable_dp and args.n_accumulation_steps > 1:
        raise ValueError("Virtual steps only works with enabled DP")

    # The following few lines, enable stats gathering about the run
    # 1. where the stats should be logged
    stats.set_global_summary_writer(
        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir)))
    # 2. enable stats
    stats.add(
        # stats about gradient norms aggregated for all layers
        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
        # stats about gradient norms per layer
        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
        # stats about clipping
        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
        # stats on training accuracy
        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
        # stats on validation accuracy
        stats.Stat(stats.StatType.TEST, "accuracy"),
    )

    # The following lines enable stat gathering for the clipping process
    # and set a default of per layer clipping for the Privacy Engine
    clipping = {"clip_per_layer": False, "enable_stat": True}

    if args.secure_rng:
        # Secure RNG is switched off in this script. An explicit raise is
        # used instead of `assert False` because asserts are stripped under
        # `python -O`, which would silently re-enable the path. To restore
        # it, build the generator with torchcsprng's
        # `create_random_device_generator("/dev/urandom")`.
        raise AssertionError("--secure-rng is disabled in this script")

    # No custom generator: the DataLoader falls back to its default RNG.
    generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ]
    # Augmentations are only applied when DP is disabled; with DP enabled
    # the training set is just normalized.
    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize)

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(root=args.data_root,
                            train=True,
                            download=True,
                            transform=train_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        drop_last=True,
        generator=generator,
    )

    test_dataset = CIFAR10(root=args.data_root,
                           train=False,
                           download=True,
                           transform=test_transform)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0
    device = torch.device(args.device)
    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
    # model = CIFAR10Model()
    model = model.to(device)

    # Momentum and weight decay are only forwarded to SGD.
    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(
            "Optimizer not recognized. Please check spelling")

    if not args.disable_dp:
        privacy_engine = PrivacyEngine(
            model,
            # Effective batch size once accumulation steps are counted.
            batch_size=args.batch_size * args.n_accumulation_steps,
            sample_size=len(train_dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.sigma,
            max_grad_norm=args.max_per_sample_grad_norm,
            secure_rng=args.secure_rng,
            **clipping,
        )
        privacy_engine.attach(optimizer)

    # Epochs are 1-based and inclusive of args.epochs.
    for epoch in range(args.start_epoch, args.epochs + 1):
        train(args, model, train_loader, optimizer, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        save_checkpoint(
            {
                # Stored as the next epoch to run.
                "epoch": epoch + 1,
                "arch": "ResNet18",
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )
Пример #3
0
def main():
    """Train and evaluate a convnet on CIFAR10, optionally distributed and with DP.

    Steps, in order: read the arguments via ``parse_args``, initialise the
    process group through ``setup`` unless running on CPU, register the
    tensorboard stats, build the CIFAR10 loaders (the training loader uses
    a batch sampler driven by ``args.sample_rate``), build and optionally
    wrap the model, create the optimizer, optionally attach a
    ``PrivacyEngine``, then run ``train``/``test`` per epoch and call
    ``save_checkpoint`` after each one. Rank 0 logs a metrics summary at
    the end.

    Raises:
        ValueError: if accumulation steps are requested with DP disabled,
            or if ``--dist_algo ddp_hook`` is used without per-layer
            clipping.
        ImportError: if ``--secure-rng`` is set but ``torchcsprng`` is
            not installed.
        NotImplementedError: for an unknown optimizer or distributed
            algorithm.
    """

    args = parse_args()

    if args.debug >= 1:
        logger.setLevel(level=logging.DEBUG)

    # Sets `world_size = 1` if you run on a single GPU with `args.local_rank = -1`
    if args.device != "cpu":
        rank, local_rank, world_size = setup(args)
        device = local_rank
    else:
        device = "cpu"
        rank = 0
        world_size = 1

    if args.disable_dp and args.n_accumulation_steps > 1:
        raise ValueError("Virtual steps only works with enabled DP")

    if args.dist_algo == "ddp_hook" and not args.clip_per_layer:
        raise ValueError(
            "Please enable `--clip_per_layer` if you want to use Opacus DDP")

    # The following few lines, enable stats gathering about the run
    # 1. where the stats should be logged
    stats.set_global_summary_writer(tensorboard.SummaryWriter(args.log_dir))
    # 2. enable stats
    stats.add(
        # stats about gradient norms aggregated for all layers
        stats.Stat(stats.StatType.GRAD, "AllLayers", frequency=0.1),
        # stats about gradient norms per layer
        stats.Stat(stats.StatType.GRAD, "PerLayer", frequency=0.1),
        # stats about clipping
        stats.Stat(stats.StatType.GRAD, "ClippingStats", frequency=0.1),
        # stats on training accuracy
        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
        # stats on validation accuracy
        stats.Stat(stats.StatType.TEST, "accuracy"),
    )

    # The following lines enable stat gathering for the clipping process
    # and set a default of per layer clipping for the Privacy Engine
    clipping = {
        "clip_per_layer": args.clip_per_layer,
        # Only rank 0 gathers clipping stats.
        "enable_stat": (rank == 0),
    }

    if args.secure_rng:
        try:
            import torchcsprng as prng
        except ImportError as e:
            msg = (
                "To use secure RNG, you must install the torchcsprng package! "
                "Check out the instructions here: https://github.com/pytorch/csprng#installation"
            )
            raise ImportError(msg) from e

        generator = prng.create_random_device_generator("/dev/urandom")

    else:
        generator = None

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ]
    # Augmentations are only applied when DP is disabled; with DP enabled
    # the training set is just normalized.
    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize)

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(root=args.data_root,
                            train=True,
                            download=True,
                            transform=train_transform)

    # The sampler type depends on whether more than one process takes part.
    if world_size > 1:
        train_sampler = DistributedPoissonBatchSampler(
            total_size=len(train_dataset),
            sample_rate=args.sample_rate,
            num_replicas=world_size,
            rank=rank,
            generator=generator,
        )

    else:
        train_sampler = UniformWithReplacementSampler(
            num_samples=len(train_dataset),
            sample_rate=args.sample_rate,
            generator=generator,
        )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        generator=generator,
        num_workers=args.workers,
        pin_memory=True,
    )

    test_dataset = CIFAR10(root=args.data_root,
                           train=False,
                           download=True,
                           transform=test_transform)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size_test,
        shuffle=False,
        num_workers=args.workers,
    )

    best_acc1 = 0

    model = convnet(num_classes=10)
    model = model.to(device)

    # Use the right distributed module wrapper if distributed training is enabled
    if world_size > 1:
        if not args.disable_dp:
            if args.dist_algo == "naive":
                model = DPDDP(model)
            elif args.dist_algo == "ddp_hook":
                model = DDP(model, device_ids=[device])
            else:
                raise NotImplementedError(
                    f"Unrecognized argument for the distributed algorithm: {args.dist_algo}"
                )
        else:
            model = DDP(model, device_ids=[device])

    # Momentum and weight decay are only forwarded to SGD.
    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(
            "Optimizer not recognized. Please check spelling")

    if not args.disable_dp:
        if args.clip_per_layer:
            # Each layer has the same clipping threshold. The total grad norm is still bounded by `args.max_per_sample_grad_norm`.
            n_layers = len([(n, p) for n, p in model.named_parameters()
                            if p.requires_grad])
            max_grad_norm = [
                args.max_per_sample_grad_norm / np.sqrt(n_layers)
            ] * n_layers
        else:
            max_grad_norm = args.max_per_sample_grad_norm

        privacy_engine = PrivacyEngine(
            model,
            # Sample rate scaled by the number of accumulation steps.
            sample_rate=args.sample_rate * args.n_accumulation_steps,
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.sigma,
            max_grad_norm=max_grad_norm,
            secure_rng=args.secure_rng,
            **clipping,
        )
        privacy_engine.attach(optimizer)

    # Store some logs
    accuracy_per_epoch = []
    time_per_epoch = []

    # Epochs are 1-based and inclusive of args.epochs.
    for epoch in range(args.start_epoch, args.epochs + 1):
        if args.lr_schedule == "cos":
            # Cosine schedule: overwrite the learning rate of every
            # parameter group at the start of each epoch.
            lr = args.lr * 0.5 * (1 + np.cos(np.pi * epoch /
                                             (args.epochs + 1)))
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        train_duration = train(args, model, train_loader, optimizer, epoch,
                               device)
        top1_acc = test(args, model, test_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = top1_acc > best_acc1
        best_acc1 = max(top1_acc, best_acc1)

        time_per_epoch.append(train_duration)
        accuracy_per_epoch.append(float(top1_acc))

        save_checkpoint(
            {
                # Stored as the next epoch to run.
                "epoch": epoch + 1,
                "arch": "Convnet",
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
            filename=args.checkpoint_file + ".tar",
        )

    if rank == 0:

        # train() durations are converted via .total_seconds().
        time_per_epoch_seconds = [t.total_seconds() for t in time_per_epoch]
        # No epoch runs when start_epoch > epochs; report 0 instead of
        # dividing by zero.
        if time_per_epoch_seconds:
            avg_time_per_epoch = sum(time_per_epoch_seconds) / len(
                time_per_epoch_seconds)
        else:
            avg_time_per_epoch = 0.0
        metrics = {
            "accuracy":
            best_acc1,
            "accuracy_per_epoch":
            accuracy_per_epoch,
            "avg_time_per_epoch_str":
            str(timedelta(seconds=int(avg_time_per_epoch))),
            "time_per_epoch":
            time_per_epoch_seconds,
        }

        logger.info(
            "\nNote:\n- 'total_time' includes the data loading time, training time and testing time.\n- 'time_per_epoch' measures the training time only.\n"
        )
        logger.info(metrics)

    if world_size > 1:
        cleanup()
Пример #4
0
def main():
    """Train a ResNet18 on CIFAR10 and write model snapshots to ``--save_path``.

    Steps, in order: parse the CLI arguments, make sure the snapshot
    directory exists, register the tensorboard stats (written under
    ``/tmp/stat/<log-dir>``), build the CIFAR10 loaders, model and
    optimizer, attach a ``PrivacyEngine`` unless ``--disable-dp`` is set,
    then run ``train``/``test`` for every epoch.

    Snapshots:
        * With DP enabled, every 5th epoch is saved as
          ``resnet18_cifar10_dp_<epoch>.tar`` together with the epsilon and
          best alpha returned by ``train`` and the test accuracy.
        * With DP disabled, only the final weights are saved, as
          ``resnet18_cifar10.pt``.

    Raises:
        NotImplementedError: if ``--optim`` is not SGD, RMSprop or Adam.
        OSError: if the snapshot directory cannot be created.
    """
    parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
    parser.add_argument(
        "-j",
        "--workers",
        default=2,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 2)",
    )
    parser.add_argument(
        "--epochs",
        default=100,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "--start-epoch",
        default=1,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        default=256,
        type=int,
        metavar="N",
        help="mini-batch size (default: 256), this is the total "
        "batch size of all GPUs on the current node when "
        "using Data Parallel or Distributed Data Parallel",
    )
    parser.add_argument(
        "-na",
        "--n_accumulation_steps",
        default=1,
        type=int,
        metavar="N",
        help="number of mini-batches to accumulate into an effective batch",
    )
    parser.add_argument(
        "--lr",
        "--learning-rate",
        default=0.001,
        type=float,
        metavar="LR",
        help="initial learning rate",
        dest="lr",
    )
    parser.add_argument("--momentum",
                        default=0.9,
                        type=float,
                        metavar="M",
                        help="SGD momentum")
    # The help texts of the next four options now state the actual defaults;
    # they previously advertised values that did not match `default=`.
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=5e-4,
        type=float,
        metavar="W",
        help="SGD weight decay (default: 5e-4)",
        dest="weight_decay",
    )
    parser.add_argument(
        "-p",
        "--print-freq",
        default=5,
        type=int,
        metavar="N",
        help="print frequency (default: 5)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "-e",
        "--evaluate",
        dest="evaluate",
        action="store_true",
        help="evaluate model on validation set",
    )
    parser.add_argument("--seed",
                        default=None,
                        type=int,
                        help="seed for initializing training. ")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="GPU ID for this process (default: 'cuda')",
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=0.001,
        metavar="S",
        help="Noise multiplier (default 0.001)",
    )
    parser.add_argument(
        "-c",
        "--max-per-sample-grad_norm",
        type=float,
        default=100.0,
        metavar="C",
        help="Clip per-sample gradients to this norm (default 100.0)",
    )
    parser.add_argument(
        "--disable-dp",
        action="store_true",
        default=False,
        help="Disable privacy training and just train with vanilla SGD",
    )
    parser.add_argument(
        "--delta",
        type=float,
        default=1e-5,
        metavar="D",
        help="Target delta (default: 1e-5)",
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="checkpoint",
        help="path to save check points",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="../cifar10",
        help="Where CIFAR10 is/will be stored",
    )
    parser.add_argument("--log-dir",
                        type=str,
                        default="",
                        help="Where Tensorboard log will be stored")
    parser.add_argument(
        "--optim",
        type=str,
        default="Adam",
        help="Optimizer to use (Adam, RMSprop, SGD)",
    )
    parser.add_argument('--save_path',
                        type=str,
                        default='/content/drive/My Drive/resnet18',
                        help="Directory where model snapshots are written")

    args = parser.parse_args()

    # Create the snapshot directory up front so a missing path fails (or is
    # fixed) before any training time is spent, not at the first save.
    os.makedirs(args.save_path, exist_ok=True)

    # The following few lines, enable stats gathering about the run
    # 1. where the stats should be logged
    stats.set_global_summary_writer(
        tensorboard.SummaryWriter(os.path.join("/tmp/stat", args.log_dir)))
    # 2. enable stats
    stats.add(
        # stats about gradient norms aggregated for all layers
        stats.Stat(stats.StatType.CLIPPING, "AllLayers", frequency=0.1),
        # stats about gradient norms per layer
        stats.Stat(stats.StatType.CLIPPING, "PerLayer", frequency=0.1),
        # stats about clipping
        stats.Stat(stats.StatType.CLIPPING, "ClippingStats", frequency=0.1),
        # stats on training accuracy
        stats.Stat(stats.StatType.TRAIN, "accuracy", frequency=0.01),
        # stats on validation accuracy
        stats.Stat(stats.StatType.TEST, "accuracy"),
    )

    # The following lines enable stat gathering for the clipping process
    # and set a default of per layer clipping for the Privacy Engine
    clipping = {"clip_per_layer": False, "enable_stat": True}
    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ]

    # Augmentations are only applied when DP is disabled; with DP enabled
    # the training set is just normalized.
    train_transform = transforms.Compose(
        augmentations + normalize if args.disable_dp else normalize)

    test_transform = transforms.Compose(normalize)

    train_dataset = CIFAR10(root=args.data_root,
                            train=True,
                            download=True,
                            transform=train_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        drop_last=True,
    )

    test_dataset = CIFAR10(root=args.data_root,
                           train=False,
                           download=True,
                           transform=test_transform)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )

    device = torch.device(args.device)
    model = convert_batchnorm_modules(models.resnet18(num_classes=10))
    model = model.to(device)

    # Momentum and weight decay are only forwarded to SGD.
    if args.optim == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    elif args.optim == "RMSprop":
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(
            "Optimizer not recognized. Please check spelling")

    if not args.disable_dp:
        privacy_engine = PrivacyEngine(
            model,
            # Effective batch size once accumulation steps are counted.
            batch_size=args.batch_size * args.n_accumulation_steps,
            sample_size=len(train_dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.sigma,
            max_grad_norm=args.max_per_sample_grad_norm,
            **clipping,
        )
        privacy_engine.attach(optimizer)

    # Epochs are 1-based and inclusive of args.epochs.
    for epoch in range(args.start_epoch, args.epochs + 1):
        if not args.disable_dp:
            # With DP enabled, train() is expected to return the privacy
            # budget spent so far and the alpha it was computed at.
            epsilon, best_alpha = train(args, model, train_loader, optimizer,
                                        epoch, device)
        else:
            train(args, model, train_loader, optimizer, epoch, device)
        top1_acc = test(args, model, test_loader, device)

        # DP runs are snapshotted every 5th epoch with their privacy numbers.
        if not args.disable_dp and epoch % 5 == 0:
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'epoch': epoch,
                    'epsilon': epsilon,
                    'best_alpha': best_alpha,
                    'accuracy': top1_acc
                },
                os.path.join(args.save_path,
                             f"resnet18_cifar10_dp_{epoch}.tar"))

    # Non-DP runs only keep the final weights.
    if args.disable_dp:
        torch.save(model.state_dict(),
                   os.path.join(args.save_path, 'resnet18_cifar10.pt'))