Example #1
0
    def test_encoder(self):
        """Sanity-check CNNEncoder output shapes and sign diversity.

        Runs 10 trials. Each trial builds a CNNEncoder with a random latent
        size z_dim drawn by np.random.randint(2, 40), feeds it a batch of 32
        random tensors of shape (1, 28, 28), and asserts that both returned
        tensors (mean, log_std) have leading dimensions (32, z_dim).
        Afterwards the outputs of all trials are pooled, and each of
        mean / log_std must contain at least one strictly positive and one
        strictly negative value.

        If constructing a default CNNEncoder() raises NotImplementedError,
        the method returns immediately without asserting anything.
        """
        # Fixed seeds make the sampled z_dims, the module initialisation and
        # the random inputs reproducible across runs.
        np.random.seed(42)
        torch.manual_seed(42)

        # NOTE(review): returning here makes an unimplemented encoder count
        # as a passing test; self.skipTest(...) would report it as skipped.
        try:
            CNNEncoder()
        except NotImplementedError:
            return

        batch_size = 32
        mean_chunks = []
        log_std_chunks = []
        for _ in range(10):
            z_dim = np.random.randint(2, 40)
            encoder = CNNEncoder(z_dim=z_dim)
            imgs = torch.randn(batch_size, 1, 28, 28)
            mean, log_std = encoder(imgs)

            self.assertTrue(
                (mean.shape[0] == batch_size and mean.shape[1] == z_dim),
                msg="The shape of the mean output should be batch_size x z_dim",
            )
            self.assertTrue(
                (log_std.shape[0] == batch_size and log_std.shape[1] == z_dim),
                msg="The shape of the log_std output should be batch_size x z_dim",
            )
            # Flatten each trial's output so tensors with different z_dim
            # can be concatenated into one 1-D tensor below.
            mean_chunks.append(mean.reshape(-1))
            log_std_chunks.append(log_std.reshape(-1))

        pooled_means = torch.cat(mean_chunks, dim=0)
        pooled_log_stds = torch.cat(log_std_chunks, dim=0)
        # Both signs must appear. Per the messages, an all-positive or
        # all-negative output is treated as suspicious.
        self.assertTrue(
            (pooled_means > 0).any() and (pooled_means < 0).any(),
            msg="Only positive or only negative means detected. Are you sure this is what you want?",
        )
        self.assertTrue(
            (pooled_log_stds > 0).any() and (pooled_log_stds < 0).any(),
            msg="Only positive or only negative log-stds detected. Are you sure this is what you want?",
        )
Example #2
0
 def __init__(self, model_name, hidden_dims, num_filters, z_dim, lr):
     """
     PyTorch Lightning module that summarizes all components to train a VAE.
     Inputs:
         model_name - String denoting what encoder/decoder class to use.  Either 'MLP' or 'CNN'
         hidden_dims - List of hidden dimensionalities to use in the MLP layers of the encoder (decoder reversed)
         num_filters - Number of channels to use in a CNN encoder/decoder
         z_dim - Dimensionality of latent space
         lr - Learning rate to use for the optimizer
     Raises:
         ValueError - if model_name is neither 'MLP' nor 'CNN'.
     """
     super().__init__()
     # lr is not read anywhere in this method; presumably it is picked up
     # from the saved hyperparameters when the optimizer is configured.
     # Confirm against the rest of the class.
     self.save_hyperparameters()
     # BCEWithLogitsLoss applies the sigmoid itself, so the decoder is
     # presumably expected to output logits (TODO confirm).
     self.bce_loss = nn.BCEWithLogitsLoss()
     if model_name == 'MLP':
         # The decoder mirrors the encoder, so it uses the hidden sizes in reverse.
         self.encoder = MLPEncoder(z_dim=z_dim, hidden_dims=hidden_dims)
         self.decoder = MLPDecoder(z_dim=z_dim, hidden_dims=hidden_dims[::-1])
     elif model_name == 'CNN':
         self.encoder = CNNEncoder(z_dim=z_dim, num_filters=num_filters)
         self.decoder = CNNDecoder(z_dim=z_dim, num_filters=num_filters)
     else:
         # Fail loudly rather than silently building the CNN models for a
         # misspelled name such as 'mlp' or 'cnn'.
         raise ValueError(
             f"Unknown model_name {model_name!r}; expected 'MLP' or 'CNN'")
    def __init__(self, model_name, hidden_dims, num_filters, z_dim, *args,
                 **kwargs):
        """
        PyTorch module that summarizes all components to train a VAE.
        Inputs:
            model_name - String denoting what encoder/decoder class to use.  Either 'MLP' or 'CNN'
            hidden_dims - List of hidden dimensionalities to use in the MLP layers of the encoder (decoder reversed)
            num_filters - Number of channels to use in a CNN encoder/decoder
            z_dim - Dimensionality of latent space
            *args, **kwargs - Accepted but ignored here; they are not forwarded
                              to super().__init__(). Presumably this lets callers
                              pass a shared config (e.g. a learning rate);
                              confirm with callers.
        Raises:
            ValueError - if model_name is neither 'MLP' nor 'CNN'.
        """
        super().__init__()
        self.z_dim = z_dim

        if model_name == 'MLP':
            # The decoder mirrors the encoder, so it uses the hidden sizes in reverse.
            self.encoder = MLPEncoder(z_dim=z_dim, hidden_dims=hidden_dims)
            self.decoder = MLPDecoder(z_dim=z_dim,
                                      hidden_dims=hidden_dims[::-1])
        elif model_name == 'CNN':
            self.encoder = CNNEncoder(z_dim=z_dim, num_filters=num_filters)
            self.decoder = CNNDecoder(z_dim=z_dim, num_filters=num_filters)
        else:
            # Fail loudly rather than silently building the CNN models for a
            # misspelled name such as 'mlp' or 'cnn'.
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected 'MLP' or 'CNN'")