Пример #1
0
    def __init__(
        self,
        hidden_dims_gen,
        hidden_dims_disc,
        dp_rate_gen,
        dp_rate_disc,
        z_dim,
        *args,
        **kwargs,
    ):
        """
        PyTorch Lightning module bundling the generator and discriminator of a GAN.

        Inputs:
            hidden_dims_gen  - List of hidden layer sizes for the generator
            hidden_dims_disc - List of hidden layer sizes for the discriminator
            dp_rate_gen      - Dropout probability used inside the generator
            dp_rate_disc     - Dropout probability used inside the discriminator
            z_dim            - Size of the latent vector fed to the generator
            *args, **kwargs  - Accepted but not used by this constructor; they
                               are NOT forwarded to super().__init__().
        """
        super().__init__()
        # Kept on the module so other methods can sample latents of the right size.
        self.z_dim = z_dim

        gen_config = dict(
            z_dim=z_dim,
            hidden_dims=hidden_dims_gen,
            dp_rate=dp_rate_gen,
        )
        disc_config = dict(
            hidden_dims=hidden_dims_disc,
            dp_rate=dp_rate_disc,
        )
        # Generator is constructed and registered before the discriminator.
        self.generator = GeneratorMLP(**gen_config)
        self.discriminator = DiscriminatorMLP(**disc_config)
Пример #2
0
 def test_shape(self):
     """The generator must map a batch of latents to images shaped [B, C, H, W]."""
     # Seed both RNGs so the test is reproducible.
     np.random.seed(42)
     torch.manual_seed(42)
     z_dim = 64
     gen = GeneratorMLP(z_dim=z_dim, output_shape=[1, 28, 28])
     z = torch.randn(4, z_dim)
     imgs = gen(z)
     # Compare the whole shape in one assertion: assertEqual reports both the
     # actual and the expected shape on failure, whereas assertTrue on a
     # compound boolean only reports "False is not true".
     self.assertEqual(
         tuple(imgs.shape),
         (4, 1, 28, 28),
         msg=
         "The output of the generator should be an image with shape [B,C,H,W]."
     )
Пример #3
0
 def test_output_values(self):
     """Generator outputs must lie within [-1, 1] and include negative values."""
     # Seed both RNGs so the latent batch is reproducible.
     np.random.seed(42)
     torch.manual_seed(42)
     latent_dim = 20
     generator = GeneratorMLP(
         z_dim=latent_dim,
         hidden_dims=[64],
         output_shape=[1, 28, 28],
     )
     # Latents are scaled by 50, presumably to drive the output activation
     # towards its bounds.
     latents = torch.randn(128, latent_dim) * 50
     samples = generator(latents)

     in_range = (samples >= -1).all() and (samples <= 1).all()
     self.assertTrue(
         in_range,
         msg=("The output of the generator should have values between -1 and 1. "
              "A tanh as output activation function might be missing."),
     )

     has_negative = (samples < 0).any()
     self.assertTrue(
         has_negative,
         msg=("The output of the generator should have values between -1 and 1, "
              "but seems to be missing negative values in your model."),
     )
Пример #4
0
    def __init__(self, hidden_dims_gen, hidden_dims_disc, dp_rate_gen,
                 dp_rate_disc, z_dim, lr):
        """
        PyTorch Lightning module that summarizes all components to train a GAN.

        Calls save_hyperparameters() and then builds a GeneratorMLP and a
        DiscriminatorMLP as the `generator` and `discriminator` submodules.

        Inputs:
            hidden_dims_gen  - List of hidden dimensionalities to use in the
                              layers of the generator
            hidden_dims_disc - List of hidden dimensionalities to use in the
                               layers of the discriminator
            dp_rate_gen      - Dropout probability to use in the generator
            dp_rate_disc     - Dropout probability to use in the discriminator
            z_dim            - Dimensionality of latent space
            lr               - Learning rate to use for the optimizer. It is not
                               used directly in this constructor; it is only
                               seen by save_hyperparameters().
        """
        super().__init__()
        # NOTE(review): Lightning's save_hyperparameters() appears to collect the
        # __init__ arguments from this call frame (presumably exposing them as
        # self.hparams, e.g. self.hparams.lr / self.hparams.z_dim). Keep it
        # called directly here, before any argument is reassigned -- confirm
        # against the Lightning docs.
        self.save_hyperparameters()

        self.generator = GeneratorMLP(z_dim=z_dim,
                                      hidden_dims=hidden_dims_gen,
                                      dp_rate=dp_rate_gen)
        self.discriminator = DiscriminatorMLP(hidden_dims=hidden_dims_disc,
                                              dp_rate=dp_rate_disc)