Example #1
    def __init__(
        self,
        hidden_dims_gen,
        hidden_dims_disc,
        dp_rate_gen,
        dp_rate_disc,
        z_dim,
        *args,
        **kwargs,
    ):
        """
        PyTorch Lightning module that summarizes all components to train a GAN.

        Inputs:
            hidden_dims_gen  - List of hidden dimensionalities to use in the
                              layers of the generator
            hidden_dims_disc - List of hidden dimensionalities to use in the
                               layers of the discriminator
            dp_rate_gen      - Dropout probability to use in the generator
            dp_rate_disc     - Dropout probability to use in the discriminator
            z_dim            - Dimensionality of latent space
        """
        super().__init__()
        self.z_dim = z_dim

        self.generator = GeneratorMLP(z_dim=z_dim,
                                      hidden_dims=hidden_dims_gen,
                                      dp_rate=dp_rate_gen)
        self.discriminator = DiscriminatorMLP(hidden_dims=hidden_dims_disc,
                                              dp_rate=dp_rate_disc)
Example #2
    def test_output_values(self):
        np.random.seed(42)
        torch.manual_seed(42)
        disc = DiscriminatorMLP(input_dims=784)
        z = torch.randn(128, 784)
        preds = disc(z)
        self.assertTrue((preds < 0).any(),
                        msg="The output of the discriminator does not have any negative values. "
                            "You might be applying a sigmoid on the discriminator output. "
                            "It is recommended to work on logits instead as this is numerically more stable. "
                            "Ensure that you are using the correct loss accordingly (BCEWithLogits instead of BCE).")
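
For reference, a minimal sketch of a discriminator that satisfies the test above, i.e. an MLP that ends in a plain linear layer and returns raw logits. The layer layout (LeakyReLU, dropout) is an assumption; the original DiscriminatorMLP implementation is not shown in these examples.

import torch
import torch.nn as nn

class DiscriminatorMLP(nn.Module):
    def __init__(self, input_dims=784, hidden_dims=(512, 256), dp_rate=0.3):
        super().__init__()
        layers = []
        in_dim = input_dims
        for hidden_dim in hidden_dims:
            layers += [nn.Linear(in_dim, hidden_dim),
                       nn.LeakyReLU(0.2),
                       nn.Dropout(dp_rate)]
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, 1))  # raw logit, no sigmoid
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Accept both flat inputs [B, 784] and image tensors [B, 1, 28, 28]
        x = x.flatten(start_dim=1)
        return self.net(x)
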
class GAN(pl.LightningModule):
    def __init__(self, hidden_dims_gen, hidden_dims_disc, dp_rate_gen,
                 dp_rate_disc, z_dim, lr):
        """
        PyTorch Lightning module that summarizes all components to train a GAN.

        Inputs:
            hidden_dims_gen  - List of hidden dimensionalities to use in the
                              layers of the generator
            hidden_dims_disc - List of hidden dimensionalities to use in the
                               layers of the discriminator
            dp_rate_gen      - Dropout probability to use in the generator
            dp_rate_disc     - Dropout probability to use in the discriminator
            z_dim            - Dimensionality of latent space
            lr               - Learning rate to use for the optimizer
        """
        super().__init__()
        self.save_hyperparameters()
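        # save_hyperparameters() stores the constructor arguments so they are
        # available later as self.hparams (e.g. self.hparams.z_dim, self.hparams.lr)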

        self.generator = GeneratorMLP(z_dim=z_dim,
                                      hidden_dims=hidden_dims_gen,
                                      dp_rate=dp_rate_gen)
        self.discriminator = DiscriminatorMLP(hidden_dims=hidden_dims_disc,
                                              dp_rate=dp_rate_disc)

    @torch.no_grad()
    def sample(self, batch_size):
        """
        Function for sampling a new batch of random images from the generator.

        Inputs:
            batch_size - Number of images to generate
        Outputs:
            x - Generated images of shape [B,C,H,W]
        """
        z = torch.randn(batch_size, self.hparams.z_dim, device=self.device)
        x = self.generator(z)
        return x

    @torch.no_grad()
    def interpolate(self, batch_size, interpolation_steps):
        """
        Function for interpolating between a batch of pairs of randomly sampled
        images. The interpolation is performed on the latent input space of the
        generator.

        Inputs:
            batch_size          - Number of image pairs to generate
            interpolation_steps - Number of intermediate interpolation points
                                  that should be generated.
        Outputs:
            x - Generated images of shape [B,interpolation_steps+2,C,H,W]
        """
        z_pairs = torch.randn(batch_size, 2, self.hparams.z_dim,
                              device=self.device)
        lambdas = torch.linspace(0, 1, interpolation_steps + 2,
                                 device=self.device)
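        # interpolation_steps intermediate points plus the two endpoints 0 and 1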
        # z has shape [batch_size, interpolation_steps + 2, z_dim]
        z = lambdas[None, :, None] * z_pairs[:, 0, None] + (
            1 - lambdas[None, :, None]) * z_pairs[:, 1, None]
        x = self.generator(z.flatten(end_dim=1))
        x = x.view(batch_size, interpolation_steps + 2,
                   *self.generator.output_shape)

        return x

    def configure_optimizers(self):
        # Create optimizer for both generator and discriminator.
        # You can use the Adam optimizer for both models.
        # It is recommended to reduce the momentum (beta1) to e.g. 0.5
        optimizer_gen = torch.optim.Adam(self.generator.parameters(),
                                         lr=self.hparams.lr,
                                         betas=(0.5, 0.999))
        optimizer_disc = torch.optim.Adam(self.discriminator.parameters(),
                                          lr=self.hparams.lr,
                                          betas=(0.5, 0.999))
        return [optimizer_gen, optimizer_disc], []

    def training_step(self, batch, batch_idx, optimizer_idx):
        """
        The training step is called for every optimizer independently. As we
        have two optimizers (Generator and Discriminator), we therefore have
        two training steps for the same batch with different optimizer_idx.
        The loss returned for optimizer_idx=0 should correspond to the first
        optimizer returned in configure_optimizers (i.e. the generator's loss).
        The second time the function is called (optimizer_idx=1),
        we optimize the discriminator. See the individual step functions
        "generator_step" and "discriminator_step" for their specific loss
        calculation.

        Inputs:
            batch         - Input batch from MNIST dataset
            batch_idx     - Index of the batch in the dataset (not needed here)
            optimizer_idx - Index of the optimizer to use for a specific
                            training step: 0 = Generator, 1 = Discriminator
        """
        x, _ = batch
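        # Note: receiving optimizer_idx here relies on Lightning's automatic
        # optimization with multiple optimizers (the 1.x API); Lightning 2.x
        # removed this argument and requires manual optimization instead.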

        if optimizer_idx == 0:
            loss = self.generator_step(x)
        elif optimizer_idx == 1:
            loss = self.discriminator_step(x)

        return loss

    def generator_step(self, x_real):
        """
        Training step for the generator. Note that you do *not* need to take
        any special care of the discriminator in terms of stopping the
        gradients to its parameters, as this is handled by having two different
        optimizers. Before the discriminator's gradients in its own step are
        calculated, the previous ones are set to zero by PyTorch Lightning's
        internal training loop. Remember to log the training loss.

        Inputs:
            x_real - Batch of images from the dataset
        Outputs:
            loss - The loss for the generator to optimize
        """

        z = torch.randn(x_real.size(0), self.hparams.z_dim,
                        device=self.device)
        x_fake = self.generator(z)
        pred_fake = self.discriminator(x_fake)

        # We use the non-saturating loss rather than the negative discriminator loss
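        # (maximize log D(G(z)) rather than minimize log(1 - D(G(z))), which gives
        # stronger gradients early in training; on logits, -logsigmoid(pred_fake)
        # is the same as BCEWithLogits against an all-ones target)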
        loss = -F.logsigmoid(pred_fake).mean()
        self.log("generator/loss", loss)

        return loss

    def discriminator_step(self, x_real):
        """
        Training step for the discriminator. Note that you do not have to use
        the same generated images as in the generator_step. It is simpler to
        sample a new batch of "fake" images, and use those for training the
        discriminator. Remember to log the training loss, and other potentially
        interesting metrics.

        Inputs:
            x_real - Batch of images from the dataset
        Outputs:
            loss - The loss for the discriminator to optimize
        """
        z = torch.randn(x_real.size(0), self.hparams.z_dim,
                        device=self.device)
        x_fake = self.generator(z)
        pred_fake = self.discriminator(x_fake)
        pred_real = self.discriminator(x_real)

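        # Discriminator loss on logits, -log D(x_real) - log(1 - D(x_fake)),
        # written with logsigmoid for numerical stability
        # (log(1 - sigmoid(a)) == logsigmoid(-a))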
        loss = (-F.logsigmoid(pred_real).mean()
                - F.logsigmoid(-pred_fake).mean())
        accuracy_real = (pred_real > 0).float().mean()
        accuracy_fake = (pred_fake < 0).float().mean()

        # Remark: there are more metrics that you can add; the discriminator
        # accuracies computed above are one example.
        self.log("discriminator/loss", loss)
        self.log("discriminator/accuracy_real", accuracy_real)
        self.log("discriminator/accuracy_fake", accuracy_fake)

        return loss
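
A minimal end-to-end training sketch for this module, assuming the GeneratorMLP and DiscriminatorMLP implementations referenced above, a PyTorch Lightning 1.x release (which still passes optimizer_idx to training_step), and a generator with tanh output. All hyperparameter values below are illustrative, not taken from the original assignment.

import pytorch_lightning as pl
import torch.utils.data as data
from torchvision import datasets, transforms

if __name__ == "__main__":
    pl.seed_everything(42)

    # MNIST scaled to [-1, 1] to match an assumed tanh generator output
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5,), (0.5,))])
    train_set = datasets.MNIST(root="data/", train=True, download=True,
                               transform=transform)
    train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True)

    model = GAN(hidden_dims_gen=[128, 256, 512],
                hidden_dims_disc=[512, 256],
                dp_rate_gen=0.1,
                dp_rate_disc=0.3,
                z_dim=32,
                lr=2e-4)

    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, train_loader)

    # Draw new samples and latent-space interpolations after training
    model.eval()
    samples = model.sample(batch_size=64)                            # [64, C, H, W]
    interp = model.interpolate(batch_size=8, interpolation_steps=5)  # [8, 7, C, H, W]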