Esempio n. 1
0
#!/usr/bin/env python

"""
Base class for Autoencoder frameworks.

Defines basic common methods and variables shared between models.
Each model overrides them as needed.
This class inherits from torch.nn.Module, ensuring that network parameters
are registered properly.
"""
import torch
import torch.nn as nn

#logging module with handmade settings.
from DiVAE import logging
# Module-level logger named after this module, obtained from the DiVAE
# logging module imported above.
logger = logging.getLogger(__name__)

from utils.helpers import OutputContainer

# Base Class for Autoencoder models
class AutoEncoderBase(nn.Module):
    """Base class for the autoencoder models.

    Holds what is common to all models; each concrete model overrides
    whatever it needs. Deriving from ``torch.nn.Module`` ensures that the
    network parameters of subclasses are registered properly.
    """

    def __init__(self, flat_input_size, train_ds_mean, activation_fct, cfg, **kwargs):
        """Initialise the ``nn.Module`` machinery and sanity-check the input size.

        Args:
            flat_input_size: Size of the input. If it is a ``list`` it must be
                non-empty; otherwise it must compare greater than zero.
            train_ds_mean: Accepted but not used by this method.
            activation_fct: Accepted but not used by this method.
            cfg: Accepted but not used by this method.
            **kwargs: Forwarded unchanged to ``nn.Module.__init__``.

        Raises:
            AssertionError: If ``flat_input_size`` is an empty list or is not
                greater than zero (only when assertions are enabled).
        """
        super().__init__(**kwargs)

        # Sanity check: the model structure cannot be built without a known
        # input dimension. The entire test sits inside the assert, so nothing
        # about flat_input_size is evaluated when running under ``python -O``.
        missing_dim_msg = "Input dimension not defined, needed for model structure"
        assert (
            len(flat_input_size) > 0
            if isinstance(flat_input_size, list)
            else flat_input_size > 0
        ), missing_dim_msg
Esempio n. 2
0
"""
Variational Autoencoder Model with hierarchical encoder

Author: Eric Drechsler ([email protected])
"""
import torch
from torch import nn
from models.autoencoders.autoencoder import AutoEncoder

from models.networks.hierarchicalEncoder import HierarchicalEncoder

#logging module with handmade settings.
from DiVAE import logging
# Module-level logger named after this module, obtained from the DiVAE
# logging module imported above.
logger = logging.getLogger(__name__)
# NOTE(review): getLogger() with no name is presumably the root logger,
# assuming DiVAE.logging mirrors the stdlib logging API — confirm. Because
# this runs at import time, importing this module then changes the level
# seen by every logger that inherits it, not just this module's logger.
logging.getLogger().setLevel(logging.INFO)


#VAE with a hierarchical posterior modelled by encoder
#samples drawn from gaussian
class HierarchicalVAE(AutoEncoder):
    def __init__(self, **kwargs):
        super(HierarchicalVAE, self).__init__(**kwargs)

        self._model_type = "HiVAE"

        self._reparamNodes = (self._config.model.n_encoder_layer_nodes,
                              self._latent_dimensions)

        self._decoder_nodes = []

        dec_hidden_node_list = list(self._config.model.decoder_hidden_nodes)