def __init__(
    self,
    hidden_dims_gen,
    hidden_dims_disc,
    dp_rate_gen,
    dp_rate_disc,
    z_dim,
    *args,
    **kwargs,
):
    """
    PyTorch Lightning module that summarizes all components to train a GAN.

    Inputs:
        hidden_dims_gen  - List of hidden dimensionalities to use in the layers of the generator
        hidden_dims_disc - List of hidden dimensionalities to use in the layers of the discriminator
        dp_rate_gen      - Dropout probability to use in the generator
        dp_rate_disc     - Dropout probability to use in the discriminator
        z_dim            - Dimensionality of latent space
    """
    super().__init__()
    self.z_dim = z_dim
    self.generator = GeneratorMLP(z_dim=z_dim,
                                  hidden_dims=hidden_dims_gen,
                                  dp_rate=dp_rate_gen)
    self.discriminator = DiscriminatorMLP(hidden_dims=hidden_dims_disc,
                                          dp_rate=dp_rate_disc)
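# ---------------------------------------------------------------------------
# Hypothetical sketch (not part of the original assignment code) of the
# generator/discriminator interfaces that the __init__ above assumes:
# GeneratorMLP(z_dim, hidden_dims, dp_rate) maps latent vectors to images and
# exposes an `output_shape` attribute, while DiscriminatorMLP(input_dims,
# hidden_dims, dp_rate) maps (flattened) images to a single raw logit.
# The actual course implementations may differ in architecture details.
# ---------------------------------------------------------------------------
import math

import torch
import torch.nn as nn


class GeneratorMLP(nn.Module):
    def __init__(self, z_dim=32, hidden_dims=(128, 256, 512), dp_rate=0.1,
                 output_shape=(1, 28, 28)):
        super().__init__()
        self.output_shape = output_shape
        layers = []
        in_dim = z_dim
        for hidden_dim in hidden_dims:
            layers += [nn.Linear(in_dim, hidden_dim),
                       nn.LeakyReLU(0.2),
                       nn.Dropout(dp_rate)]
            in_dim = hidden_dim
        # Tanh keeps the generated pixels in [-1, 1], matching inputs that are
        # normalized to that range.
        layers += [nn.Linear(in_dim, math.prod(output_shape)), nn.Tanh()]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        x = self.net(z)
        return x.view(z.shape[0], *self.output_shape)


class DiscriminatorMLP(nn.Module):
    def __init__(self, input_dims=784, hidden_dims=(512, 256), dp_rate=0.3):
        super().__init__()
        layers = []
        in_dim = input_dims
        for hidden_dim in hidden_dims:
            layers += [nn.Linear(in_dim, hidden_dim),
                       nn.LeakyReLU(0.2),
                       nn.Dropout(dp_rate)]
            in_dim = hidden_dim
        # Final layer outputs a raw logit; no sigmoid here (see the test below).
        layers += [nn.Linear(in_dim, 1)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x.flatten(start_dim=1))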
def test_output_values(self):
    np.random.seed(42)
    torch.manual_seed(42)
    disc = DiscriminatorMLP(input_dims=784)
    z = torch.randn(128, 784)
    preds = disc(z)
    self.assertTrue((preds < 0).any(),
                    msg="The output of the discriminator does not have any negative values. "
                        "You might be applying a sigmoid on the discriminator output. "
                        "It is recommended to work on logits instead, as this is numerically more stable. "
                        "Ensure that you are using the corresponding loss accordingly "
                        "(BCEWithLogits instead of BCE).")
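# ---------------------------------------------------------------------------
# Illustrative aside (not part of the original test file): the check above
# expects raw logits because log-probabilities computed from saturated
# sigmoids lose precision. log(1 - sigmoid(x)) equals logsigmoid(-x), but only
# the latter stays finite for large logits; the same reasoning is why
# BCEWithLogitsLoss is preferred over sigmoid followed by BCELoss.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F

logit = torch.tensor(40.0)
print(torch.log(1 - torch.sigmoid(logit)))  # -inf: sigmoid(40) rounds to 1.0 in float32
print(F.logsigmoid(-logit))                 # tensor(-40.): numerically stable equivalent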
class GAN(pl.LightningModule):

    def __init__(self, hidden_dims_gen, hidden_dims_disc, dp_rate_gen, dp_rate_disc, z_dim, lr):
        """
        PyTorch Lightning module that summarizes all components to train a GAN.

        Inputs:
            hidden_dims_gen  - List of hidden dimensionalities to use in the layers of the generator
            hidden_dims_disc - List of hidden dimensionalities to use in the layers of the discriminator
            dp_rate_gen      - Dropout probability to use in the generator
            dp_rate_disc     - Dropout probability to use in the discriminator
            z_dim            - Dimensionality of latent space
            lr               - Learning rate to use for the optimizer
        """
        super().__init__()
        self.save_hyperparameters()
        self.generator = GeneratorMLP(z_dim=z_dim,
                                      hidden_dims=hidden_dims_gen,
                                      dp_rate=dp_rate_gen)
        self.discriminator = DiscriminatorMLP(hidden_dims=hidden_dims_disc,
                                              dp_rate=dp_rate_disc)

    @torch.no_grad()
    def sample(self, batch_size):
        """
        Function for sampling a new batch of random images from the generator.

        Inputs:
            batch_size - Number of images to generate
        Outputs:
            x - Generated images of shape [B,C,H,W]
        """
        # Sample latents on the same device as the model so that sampling also
        # works after the module has been moved to the GPU.
        z = torch.randn(batch_size, self.hparams.z_dim, device=self.device)
        x = self.generator(z)
        return x

    @torch.no_grad()
    def interpolate(self, batch_size, interpolation_steps):
        """
        Function for interpolating between a batch of pairs of randomly sampled
        images. The interpolation is performed on the latent input space of the
        generator.

        Inputs:
            batch_size - Number of image pairs to generate
            interpolation_steps - Number of intermediate interpolation points
                                  that should be generated.
        Outputs:
            x - Generated images of shape [B,interpolation_steps+2,C,H,W]
        """
        z_pairs = torch.randn(batch_size, 2, self.hparams.z_dim, device=self.device)
        lambdas = torch.linspace(0, 1, interpolation_steps + 2, device=self.device)
        # z has shape [batch_size, interpolation_steps + 2, z_dim]
        z = lambdas[None, :, None] * z_pairs[:, 0, None] + \
            (1 - lambdas[None, :, None]) * z_pairs[:, 1, None]
        x = self.generator(z.flatten(end_dim=1))
        x = x.view(batch_size, interpolation_steps + 2, *self.generator.output_shape)
        return x

    def configure_optimizers(self):
        # Create an optimizer for both the generator and the discriminator.
        # You can use the Adam optimizer for both models.
        # It is recommended to reduce the momentum (beta1) to e.g. 0.5.
        optimizer_gen = torch.optim.Adam(self.generator.parameters(),
                                         lr=self.hparams.lr,
                                         betas=(0.5, 0.999))
        optimizer_disc = torch.optim.Adam(self.discriminator.parameters(),
                                          lr=self.hparams.lr,
                                          betas=(0.5, 0.999))
        return [optimizer_gen, optimizer_disc], []

    def training_step(self, batch, batch_idx, optimizer_idx):
        """
        The training step is called for every optimizer independently. As we
        have two optimizers (generator and discriminator), we therefore have
        two training steps for the same batch with different optimizer_idx.
        The loss returned for optimizer_idx=0 should correspond to the loss of
        the first optimizer returned in configure_optimizers (i.e. the one for
        the generator). The second time the function is called
        (optimizer_idx=1), we optimize the discriminator. See the individual
        step functions "generator_step" and "discriminator_step" for their
        specific loss calculations.

        Inputs:
            batch - Input batch from MNIST dataset
            batch_idx - Index of the batch in the dataset (not needed here)
            optimizer_idx - Index of the optimizer to use for a specific
                            training step - 0 = Generator, 1 = Discriminator
        """
        x, _ = batch
        if optimizer_idx == 0:
            loss = self.generator_step(x)
        elif optimizer_idx == 1:
            loss = self.discriminator_step(x)
        return loss

    def generator_step(self, x_real):
        """
        Training step for the generator.

        Note that you do *not* need to take any special care of the
        discriminator in terms of stopping the gradients to its parameters, as
        this is handled by having two different optimizers. Before the
        discriminator's gradients are calculated in its own step, the previous
        ones are set to zero by PyTorch Lightning's internal training loop.
        Remember to log the training loss.

        Inputs:
            x_real - Batch of images from the dataset
        Outputs:
            loss - The loss for the generator to optimize
        """
        z = torch.randn(x_real.size(0), self.hparams.z_dim, device=self.device)
        x_fake = self.generator(z)
        pred_fake = self.discriminator(x_fake)
        # We use the non-saturating loss rather than the negative discriminator loss.
        loss = -F.logsigmoid(pred_fake).mean()
        self.log("generator/loss", loss)
        return loss

    def discriminator_step(self, x_real):
        """
        Training step for the discriminator.

        Note that you do not have to use the same generated images as in
        generator_step. It is simpler to sample a new batch of "fake" images
        and use those for training the discriminator.
        Remember to log the training loss, and other potentially interesting
        metrics.

        Inputs:
            x_real - Batch of images from the dataset
        Outputs:
            loss - The loss for the discriminator to optimize
        """
        z = torch.randn(x_real.size(0), self.hparams.z_dim, device=self.device)
        x_fake = self.generator(z)
        pred_fake = self.discriminator(x_fake)
        pred_real = self.discriminator(x_real)
        # log(1 - sigmoid(x)) equals logsigmoid(-x); the latter stays
        # numerically stable for large logits.
        loss = -F.logsigmoid(pred_real).mean() - F.logsigmoid(-pred_fake).mean()

        # Further metrics, such as the discriminator's accuracy on real and
        # fake images, are useful for monitoring training.
        accuracy_real = (pred_real > 0).float().mean()
        accuracy_fake = (pred_fake < 0).float().mean()

        self.log("discriminator/loss", loss)
        self.log("discriminator/accuracy_real", accuracy_real)
        self.log("discriminator/accuracy_fake", accuracy_fake)
        return loss
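# ---------------------------------------------------------------------------
# Minimal training sketch (not part of the original code). It assumes a
# PyTorch Lightning version below 2.0, where automatic optimization supports
# multiple optimizers and passes optimizer_idx to training_step as used above.
# All hyperparameter values, the data pipeline, and the train_gan helper are
# illustrative choices, not prescribed by the assignment.
# ---------------------------------------------------------------------------
import pytorch_lightning as pl
import torch.utils.data as data
from torchvision import datasets, transforms


def train_gan(max_epochs=10):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),  # map pixel values to [-1, 1]
    ])
    train_set = datasets.MNIST(root="data/", train=True, download=True,
                               transform=transform)
    train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True,
                                   num_workers=2, drop_last=True)

    model = GAN(hidden_dims_gen=[128, 256, 512],
                hidden_dims_disc=[512, 256],
                dp_rate_gen=0.1,
                dp_rate_disc=0.3,
                z_dim=32,
                lr=2e-4)
    trainer = pl.Trainer(max_epochs=max_epochs)
    trainer.fit(model, train_loader)
    return model


if __name__ == "__main__":
    gan = train_gan(max_epochs=10)
    samples = gan.sample(batch_size=16)  # [16, C, H, W] generated images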