def test_privacy_engine_to_example(self) -> None:
     """Run the usage example for ``PrivacyEngine.to()`` end to end.

     The test contains no assertions: it passes as long as building the
     engine and moving both the model and the engine to ``device`` does
     not raise.
     """
     # IMPORTANT: When changing this code you also need to update
     # the docstring for opacus.privacy_engine.PrivacyEngine.to()
     model = torch.nn.Linear(16, 32)  # An example model. Default device is CPU
     # Wrap the example model in a privacy engine using fixed example
     # hyper-parameters.
     privacy_engine = PrivacyEngine(
         model,
         sample_rate=0.01,
         noise_multiplier=0.8,
         max_grad_norm=0.5,
     )
     # The example stays on CPU (presumably so it runs on machines without a
     # GPU -- confirm); the same two calls are used for any other device.
     device = "cpu"
     model.to(
         device
     )  # Whenever the model is moved (e.g. to GPU), the privacy engine must be moved too (next line)
     privacy_engine.to(device)
# Example no. 2
# 0
def main():
    """Train a GAN on MNIST, optionally with a private discriminator.

    Parses the command-line arguments, builds the generator/discriminator
    pair with RMSprop optimizers and ``LambdaLR`` schedulers, and runs the
    training loop: ``args.step_train_discriminator`` discriminator updates
    followed by one generator update per batch.

    When ``args.privacy_noise_multiplier`` is non-zero, a ``PrivacyEngine``
    configured with that noise multiplier is attached to the discriminator
    optimizer, and the privacy budget spent so far (epsilon, evaluated at
    ``args.privacy_delta``) is logged every iteration. When it is zero, no
    engine is created and epsilon is logged as 0.
    """
    parser = ArgParser()
    args = parser.parse_args()

    gen = Generator(args.latent_dim).to(args.device)
    disc = Discriminator().to(args.device)
    if args.device != 'cpu':
        gen = nn.DataParallel(gen, args.gpu_ids)
        disc = nn.DataParallel(disc, args.gpu_ids)
    # gen = gen.apply(weights_init)
    # disc = disc.apply(weights_init)

    gen_opt = torch.optim.RMSprop(gen.parameters(), lr=args.lr)
    disc_opt = torch.optim.RMSprop(disc.parameters(), lr=args.lr)
    gen_scheduler = torch.optim.lr_scheduler.LambdaLR(
        gen_opt, lr_lambda=lr_lambda(args.num_epochs)
    )
    disc_scheduler = torch.optim.lr_scheduler.LambdaLR(
        disc_opt, lr_lambda=lr_lambda(args.num_epochs)
    )
    disc_loss_fn = DiscriminatorLoss().to(args.device)
    gen_loss_fn = GeneratorLoss().to(args.device)

    # dataset = Dataset()
    dataset = MNISTDataset()
    loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    logger = TrainLogger(args, len(loader), phase=None)
    logger.log_hparams(args)

    # A noise multiplier of 0 means "train without differential privacy".
    use_privacy = args.privacy_noise_multiplier != 0
    privacy_engine = None
    if use_privacy:
        privacy_engine = PrivacyEngine(
            disc,
            batch_size=args.batch_size,
            sample_size=len(dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            # Honor the value given on the command line; it was previously
            # ignored in favor of a hard-coded 0.8.
            noise_multiplier=args.privacy_noise_multiplier,
            max_grad_norm=0.02,
            batch_first=True,
        )
        privacy_engine.attach(disc_opt)
        privacy_engine.to(args.device)

    # Logged as 0 for the whole run when privacy is disabled.
    epsilon = 0

    for epoch in range(args.num_epochs):
        logger.start_epoch()
        for img in tqdm(loader, dynamic_ncols=True):
            logger.start_iter()
            img = img.to(args.device)

            # Discriminator updates. `fake` and `disc_loss` keep the values
            # from the last update so they can be logged below; they stay
            # None if step_train_discriminator is 0.
            fake, disc_loss = None, None
            for _ in range(args.step_train_discriminator):
                disc_opt.zero_grad()
                fake_noise = get_noise(args.batch_size, args.latent_dim, device=args.device)
                fake = gen(fake_noise)
                disc_loss = disc_loss_fn(img, fake, disc)
                disc_loss.backward()
                disc_opt.step()

            # Single generator update on a fresh noise batch.
            gen_opt.zero_grad()
            fake_noise_2 = get_noise(args.batch_size, args.latent_dim, device=args.device)
            fake_2 = gen(fake_noise_2)
            gen_loss = gen_loss_fn(img, fake_2, disc)
            gen_loss.backward()
            gen_opt.step()

            if use_privacy:
                # best_alpha is returned alongside epsilon but is not logged.
                epsilon, best_alpha = privacy_engine.get_privacy_spent(args.privacy_delta)

            logger.log_iter_gan_from_latent_vector(img, fake, gen_loss, disc_loss, epsilon)
            logger.end_iter()

        logger.end_epoch()
        # Learning-rate schedules advance once per epoch.
        gen_scheduler.step()
        disc_scheduler.step()