#!/usr/bin/env python
"""
Base class for Autoencoder frameworks.

Defines basic common methods and variables shared between models.
Each model overrides them as needed.
This class inherits from torch.nn.Module, ensuring that network parameters
are registered properly.
"""
import torch
import torch.nn as nn

# logging module with handmade settings.
from DiVAE import logging
logger = logging.getLogger(__name__)

from utils.helpers import OutputContainer


# Base Class for Autoencoder models
class AutoEncoderBase(nn.Module):
    """Common base class for the autoencoder models (a torch.nn.Module subclass)."""

    def __init__(self, flat_input_size, train_ds_mean, activation_fct, cfg, **kwargs):
        """Initialise the underlying nn.Module and sanity-check the input size.

        Args:
            flat_input_size: input dimension, either a single number that must
                be > 0, or a non-empty sequence (list, tuple, or torch.Size).
            train_ds_mean: part of the shared constructor signature.
            activation_fct: part of the shared constructor signature.
            cfg: part of the shared constructor signature.
            **kwargs: forwarded unchanged to nn.Module.__init__.

        Raises:
            AssertionError: if flat_input_size is an empty sequence or a
                number that is not > 0 (only while assertions are enabled;
                `python -O` strips these checks).

        NOTE(review): train_ds_mean, activation_fct and cfg are not used in
        this constructor body; confirm whether they should be stored here.
        """
        super().__init__(**kwargs)

        # sanity checks
        # Tuples are treated like lists: torch.Size (returned by tensor.shape)
        # is a tuple subclass, and `tuple > 0` would raise a TypeError instead
        # of producing the informative message below.
        if isinstance(flat_input_size, (list, tuple)):
            assert len(flat_input_size) > 0, "Input dimension not defined, needed for model structure"
        else:
            assert flat_input_size > 0, "Input dimension not defined, needed for model structure"
""" Variational Autoencoder Model with hierarchical encoder Author: Eric Drechsler ([email protected]) """ import torch from torch import nn from models.autoencoders.autoencoder import AutoEncoder from models.networks.hierarchicalEncoder import HierarchicalEncoder #logging module with handmade settings. from DiVAE import logging logger = logging.getLogger(__name__) logging.getLogger().setLevel(logging.INFO) #VAE with a hierarchical posterior modelled by encoder #samples drawn from gaussian class HierarchicalVAE(AutoEncoder): def __init__(self, **kwargs): super(HierarchicalVAE, self).__init__(**kwargs) self._model_type = "HiVAE" self._reparamNodes = (self._config.model.n_encoder_layer_nodes, self._latent_dimensions) self._decoder_nodes = [] dec_hidden_node_list = list(self._config.model.decoder_hidden_nodes)