Code Example #1
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from opacus import PrivacyEngine

# Project-specific helpers (ArgParser, Generator, Discriminator, weights_init,
# lr_lambda, DiscriminatorLoss, GeneratorLoss, MNISTDataset, TrainLogger,
# get_noise) are assumed to be defined elsewhere in the project.

def main():
    parser = ArgParser()
    args = parser.parse_args()

    gen = Generator(args.latent_dim).to(args.device)
    disc = Discriminator().to(args.device)
    if args.device != 'cpu':
        gen = nn.DataParallel(gen, args.gpu_ids)
        disc = nn.DataParallel(disc, args.gpu_ids)
    # gen = gen.apply(weights_init)
    # disc = disc.apply(weights_init)

    gen_opt = torch.optim.RMSprop(gen.parameters(), lr=args.lr)
    disc_opt = torch.optim.RMSprop(disc.parameters(), lr=args.lr)
    gen_scheduler = torch.optim.lr_scheduler.LambdaLR(gen_opt, lr_lambda=lr_lambda(args.num_epochs))
    disc_scheduler = torch.optim.lr_scheduler.LambdaLR(disc_opt, lr_lambda=lr_lambda(args.num_epochs))
    disc_loss_fn = DiscriminatorLoss().to(args.device)
    gen_loss_fn = GeneratorLoss().to(args.device)

    # dataset = Dataset()
    dataset = MNISTDataset()
    loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    logger = TrainLogger(args, len(loader), phase=None)
    logger.log_hparams(args)

    # Differential privacy: wrap the discriminator's optimizer so per-sample
    # gradients are clipped and noised (DP-SGD). This is the Opacus 0.x API.
    if args.privacy_noise_multiplier != 0:
        privacy_engine = PrivacyEngine(
            disc,
            batch_size=args.batch_size,
            sample_size=len(dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=args.privacy_noise_multiplier,
            max_grad_norm=0.02,
            batch_first=True,
        )
        privacy_engine.attach(disc_opt)
        privacy_engine.to(args.device)

    for epoch in range(args.num_epochs):
        logger.start_epoch()
        for cur_step, img in enumerate(tqdm(loader, dynamic_ncols=True)):
            logger.start_iter()
            img = img.to(args.device)
            fake, disc_loss = None, None
            # Update the critic several times per generator step (WGAN-style).
            for _ in range(args.step_train_discriminator):
                disc_opt.zero_grad()
                # Match the actual batch size so the last, possibly smaller,
                # batch does not cause a shape mismatch in the loss.
                fake_noise = get_noise(img.size(0), args.latent_dim, device=args.device)
                fake = gen(fake_noise)
                disc_loss = disc_loss_fn(img, fake, disc)
                disc_loss.backward()
                disc_opt.step()

            gen_opt.zero_grad()
            fake_noise_2 = get_noise(img.size(0), args.latent_dim, device=args.device)
            fake_2 = gen(fake_noise_2)
            gen_loss = gen_loss_fn(img, fake_2, disc)
            gen_loss.backward()
            gen_opt.step()
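            # Track the cumulative (epsilon, delta) privacy budget spent so far.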
            if args.privacy_noise_multiplier != 0:
                epsilon, best_alpha = privacy_engine.get_privacy_spent(args.privacy_delta)

            logger.log_iter_gan_from_latent_vector(
                img, fake, gen_loss, disc_loss,
                epsilon if args.privacy_noise_multiplier != 0 else 0,
            )
            logger.end_iter()

        logger.end_epoch()
        gen_scheduler.step()
        disc_scheduler.step()
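
Example #1 relies on several helpers that are not shown. The sketch below is an assumption about what `get_noise`, `lr_lambda`, and the two loss modules might look like, not the project's actual definitions; the RMSprop optimizers, the multiple critic steps per generator step, and the tight gradient-norm clip are all consistent with a WGAN-style objective.

import torch
import torch.nn as nn

def get_noise(n_samples, latent_dim, device='cpu'):
    # One standard-normal latent vector per sample.
    return torch.randn(n_samples, latent_dim, device=device)

def lr_lambda(num_epochs):
    # LambdaLR multiplier: decay the learning rate linearly to zero.
    def schedule(epoch):
        return max(0.0, 1.0 - epoch / num_epochs)
    return schedule

class DiscriminatorLoss(nn.Module):
    # WGAN-style critic loss: minimizing this maximizes D(real) - D(fake).
    def forward(self, real, fake, disc):
        return disc(fake.detach()).mean() - disc(real).mean()

class GeneratorLoss(nn.Module):
    # WGAN-style generator loss: minimizing this maximizes D(fake).
    def forward(self, real, fake, disc):
        return -disc(fake).mean()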
Code Example #2
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics
import pytorch_lightning as pl
from opacus import PrivacyEngine

class LitSampleConvNetClassifier(pl.LightningModule):
    def __init__(
        self,
        lr: float = 0.1,
        enable_dp: bool = True,
        delta: float = 1e-5,
        sample_rate: float = 0.001,
        sigma: float = 1.0,
        max_per_sample_grad_norm: float = 1.0,
        secure_rng: bool = False,
    ):
        """A simple conv-net for classifying MNIST with differential privacy

        Args:
            lr: Learning rate
            enable_dp: Enables training with privacy guarantees using Opacus (if True), vanilla SGD otherwise
            delta: Target delta for which (eps, delta)-DP is computed
            sample_rate: Sample rate used for batch construction
            sigma: Noise multiplier
            max_per_sample_grad_norm: Clip per-sample gradients to this norm
            secure_rng: Use secure random number generator
        """
        super().__init__()

        # Hyper-parameters
        self.lr = lr
        self.enable_dp = enable_dp
        self.delta = delta
        self.sample_rate = sample_rate
        self.sigma = sigma
        self.max_per_sample_grad_norm = max_per_sample_grad_norm
        self.secure_rng = secure_rng

        # Parameters
        self.conv1 = nn.Conv2d(1, 16, 8, 2, padding=3)
        self.conv2 = nn.Conv2d(16, 32, 4, 2)
        self.fc1 = nn.Linear(32 * 4 * 4, 32)
        self.fc2 = nn.Linear(32, 10)

        # Privacy engine
        self.privacy_engine = None  # Created before training

        # Metrics
        self.test_accuracy = torchmetrics.Accuracy()

    def forward(self, x):
        # x of shape [B, 1, 28, 28]
        x = F.relu(self.conv1(x))  # -> [B, 16, 14, 14]
        x = F.max_pool2d(x, 2, 1)  # -> [B, 16, 13, 13]
        x = F.relu(self.conv2(x))  # -> [B, 32, 5, 5]
        x = F.max_pool2d(x, 2, 1)  # -> [B, 32, 4, 4]
        x = x.view(-1, 32 * 4 * 4)  # -> [B, 512]
        x = F.relu(self.fc1(x))  # -> [B, 32]
        x = self.fc2(x)  # -> [B, 10]
        return x

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=0)
        return optimizer

    def training_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.log("train_loss",
                 loss,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        data, target = batch
        output = self(data)
        loss = F.cross_entropy(output, target)
        self.test_accuracy(output, target)
        self.log("test_loss",
                 loss,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True)
        self.log("test_accuracy",
                 self.test_accuracy,
                 on_step=False,
                 on_epoch=True)
        return loss

    # Differential privacy: attach the engine before training, report the
    # privacy budget each epoch, and detach when training ends.

    def on_train_start(self) -> None:
        if self.enable_dp:
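            # Opacus 0.x API: the engine wraps the LightningModule and is
            # attached to the optimizer from configure_optimizers().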
            self.privacy_engine = PrivacyEngine(
                self,
                sample_rate=self.sample_rate,
                alphas=[1 + x / 10.0
                        for x in range(1, 100)] + list(range(12, 64)),
                noise_multiplier=self.sigma,
                max_grad_norm=self.max_per_sample_grad_norm,
                secure_rng=self.secure_rng,
            )

            optimizer = self.optimizers()
            self.privacy_engine.attach(optimizer)

    def on_train_epoch_end(self):
        if self.enable_dp:
            epsilon, best_alpha = self.privacy_engine.get_privacy_spent(
                self.delta)
            # Privacy spent: (epsilon, delta) for alpha
            self.log("epsilon", epsilon, on_epoch=True, prog_bar=True)
            self.log("alpha", best_alpha, on_epoch=True, prog_bar=True)

    def on_train_end(self):
        if self.enable_dp:
            self.privacy_engine.detach()
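
A minimal sketch of how this module might be driven, assuming torchvision is available; the normalization constants are the standard MNIST statistics, and batch_size=60 matches sample_rate=0.001 on the 60,000-image training set. This is a usage assumption, not part of the original example.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pytorch_lightning as pl

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # standard MNIST statistics
])
train_ds = datasets.MNIST("data", train=True, download=True, transform=transform)
# sample_rate = batch_size / len(train_ds) = 60 / 60000 = 0.001
train_loader = DataLoader(train_ds, batch_size=60, shuffle=True)

model = LitSampleConvNetClassifier(lr=0.1, sigma=1.0, enable_dp=True)
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader)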