Example #1
0
def autoencoder(x, dropout_rate, dropout, config):
    """Build a convolutional autoencoder graph with a dense latent bottleneck.

    The encoder output is squeezed channel-wise by a 1x1 convolution (to 1/8 of
    its channels), flattened and projected to a ``config.zDim``-dimensional code
    ``z``. The code is projected back, reshaped to the squeezed feature-map
    shape, expanded again by a second 1x1 convolution and decoded.

    Args:
        x: Input image tensor. Its static shape is passed to
            ``build_unified_encoder``. Index 3 of the encoder output shape is
            read as the channel axis, so NHWC is assumed -- TODO confirm.
        dropout_rate: Rate given to the shared ``Dropout`` layer.
        dropout: Passed as the ``training`` argument of every dropout call
            (bool or boolean tensor).
        config: Provides ``intermediateResolutions``, ``zDim``,
            ``outputWidth`` and ``numChannels``.

    Returns:
        dict with ``'z'`` (latent code) and ``'x_hat'`` (decoder output).
    """
    outputs = {}

    with tf.variable_scope('Encoder'):
        encoder = build_unified_encoder(x.get_shape().as_list(), config.intermediateResolutions)

        temp_out = x
        for layer in encoder:
            temp_out = layer(temp_out)

    with tf.variable_scope("Bottleneck"):
        intermediate_conv = Conv2D(temp_out.get_shape().as_list()[3] // 8, 1, padding='same')
        intermediate_conv_reverse = Conv2D(temp_out.get_shape().as_list()[3], 1, padding='same')
        dropout_layer = Dropout(dropout_rate)
        temp_out = intermediate_conv(temp_out)

        # Shape (without batch) of the squeezed feature map; the decoder side
        # reshapes the dense projection back to exactly this shape.
        reshape = temp_out.get_shape().as_list()[1:]
        z_layer = Dense(config.zDim)
        dec_dense = Dense(np.prod(reshape))

        outputs['z'] = z = dropout_layer(z_layer(Flatten()(temp_out)), dropout)
        # Pass `dropout` as the training flag here as well. Without it Keras
        # falls back to the global learning phase, so this dropout would not
        # follow the caller's switch like the one applied to `z` above.
        temp_out = intermediate_conv_reverse(tf.reshape(dropout_layer(dec_dense(z), dropout), [-1, *reshape]))

    with tf.variable_scope('Decoder'):
        decoder = build_unified_decoder(config.outputWidth, config.intermediateResolutions, config.numChannels)
        # Decode: z -> x_hat
        for layer in decoder:
            temp_out = layer(temp_out)

        outputs['x_hat'] = temp_out

    return outputs
Example #2
0
def autoencoder_spatial(x, dropout_rate, dropout, config):
    """Build a fully convolutional autoencoder whose latent is a feature map.

    There is no dense bottleneck: the (dropout-regularised) encoder feature map
    is itself the latent ``z`` and is fed straight into the decoder.

    Args:
        x: Input image tensor; its static shape configures the encoder.
        dropout_rate: Rate of the ``Dropout`` layer applied to the encoder output.
        dropout: ``training`` flag for that dropout call.
        config: Provides ``intermediateResolutions``, ``outputWidth`` and
            ``numChannels``.

    Returns:
        dict with ``'z'`` (encoder feature map after dropout) and ``'x_hat'``.
    """
    with tf.variable_scope('Encoder'):
        encoder_layers = build_unified_encoder(x.get_shape().as_list(), config.intermediateResolutions)
        feature_dropout = Dropout(dropout_rate)
        features = x
        for enc_layer in encoder_layers:
            features = enc_layer(features)
        features = feature_dropout(features, training=dropout)
    latent = features

    with tf.variable_scope('Decoder'):
        decoder_layers = build_unified_decoder(config.outputWidth, config.intermediateResolutions, config.numChannels)
        # Decode: z -> x_hat
        reconstruction = latent
        for dec_layer in decoder_layers:
            reconstruction = dec_layer(reconstruction)

    return {'z': latent, 'x_hat': reconstruction}
Example #3
0
def variational_autoencoder(x, dropout_rate, dropout, config):
    """Build a variational autoencoder graph with a dense Gaussian latent.

    The encoder features are squeezed by a 1x1 convolution, flattened and mapped
    to ``z_mu`` and ``z_log_sigma``. A latent sample
    ``z_mu + eps * exp(z_log_sigma)`` with ``eps ~ N(0, I)`` is projected back
    to the squeezed feature-map shape, expanded by a 1x1 convolution and decoded.

    Args:
        x: Input image tensor. Index 3 of the encoder output shape is read as
            the channel axis (NHWC assumed -- TODO confirm).
        dropout_rate: Rate of the shared ``Dropout`` layer.
        dropout: ``training`` flag passed to every dropout call.
        config: Provides ``intermediateResolutions``, ``zDim``,
            ``outputWidth`` and ``numChannels``.

    Returns:
        dict with ``'z_mu'``, ``'z_log_sigma'``, ``'z_sigma'`` and ``'x_hat'``.
    """
    with tf.variable_scope('Encoder'):
        encoder_layers = build_unified_encoder(x.get_shape().as_list(), config.intermediateResolutions)
        features = x
        for enc_layer in encoder_layers:
            features = enc_layer(features)

    with tf.variable_scope("Bottleneck"):
        full_channels = features.get_shape().as_list()[3]
        squeeze_conv = Conv2D(full_channels // 8, 1, padding='same')
        expand_conv = Conv2D(full_channels, 1, padding='same')
        drop = Dropout(dropout_rate)
        squeezed = squeeze_conv(features)
        squeezed_shape = squeezed.get_shape().as_list()[1:]

        mu_dense = Dense(config.zDim)
        log_sigma_dense = Dense(config.zDim)
        expand_dense = Dense(np.prod(squeezed_shape))

        flat = Flatten()(squeezed)
        z_mu = drop(mu_dense(flat), dropout)
        z_log_sigma = drop(log_sigma_dense(flat), dropout)
        z_sigma = tf.exp(z_log_sigma)
        # Reparameterisation: sample eps ~ N(0, I) and shift/scale it.
        z_vae = z_mu + tf.random_normal(tf.shape(z_sigma)) * z_sigma
        unflattened = tf.reshape(drop(expand_dense(z_vae), dropout), [-1, *squeezed_shape])
        decoded = expand_conv(unflattened)

    with tf.variable_scope('Decoder'):
        decoder_layers = build_unified_decoder(config.outputWidth, config.intermediateResolutions, config.numChannels)

        # Decode: z -> x_hat
        for dec_layer in decoder_layers:
            decoded = dec_layer(decoded)

    return {'z_mu': z_mu, 'z_log_sigma': z_log_sigma, 'z_sigma': z_sigma, 'x_hat': decoded}
Example #4
0
File: fanogan.py Project: irfixq/AE
def fanogan(z, x, dropout_rate, dropout, config):
    """Build an f-AnoGAN style graph: encoder, generator and discriminator.

    The generator (a dense projection, a 1x1 convolution and an unbatchnormed
    decoder) is applied twice with shared weights: once to the encoder's code
    ``z_enc`` (output ``x_enc``) and once to the caller-supplied code ``z``
    (output ``x_``). One discriminator with shared weights is then applied to
    ``x_``, ``x``, a random interpolate ``x_hat`` and ``x_enc``.

    Args:
        z: Latent codes supplied by the caller for the generator path.
            Presumably shape ``(batch, config.zDim)`` -- confirm with caller.
        x: Real input image tensor. Index 3 of the encoder output shape is read
            as channels (NHWC assumed -- TODO confirm).
        dropout_rate: Rate of the single ``Dropout`` layer, which is shared by
            the encoder and generator.
        dropout: ``training`` flag passed to every dropout call.
        config: Provides ``intermediateResolutions``, ``zDim``,
            ``outputWidth``, ``numChannels`` and ``batchsize``.

    Returns:
        dict with keys ``'z_enc'``, ``'x_enc'``, ``'x_'``,
        ``'d_fake_features'``, ``'d_'``, ``'d_features'``, ``'d'``,
        ``'x_hat'``, ``'d_hat_features'``, ``'d_hat'``, ``'d_enc_features'``
        and ``'d_enc'``.
    """
    outputs = {}

    # Encoder
    with tf.variable_scope('Encoder'):
        encoder = build_unified_encoder(x.get_shape().as_list(),
                                        config.intermediateResolutions)

        temp_out = x
        for layer in encoder:
            temp_out = layer(temp_out)

        # Keep the un-squeezed encoder output; its channel count sizes the
        # generator's expanding 1x1 convolution below.
        temp_temp_out = temp_out
        intermediate_conv = Conv2D(temp_temp_out.get_shape().as_list()[3] // 8,
                                   1,
                                   padding='same')
        dropout_layer = Dropout(dropout_rate)
        temp_out = intermediate_conv(temp_out)

        reshape = temp_out.get_shape().as_list()[1:]
        z_layer = Dense(config.zDim)
        # tanh bounds the encoded code to (-1, 1).
        outputs['z_enc'] = z_enc = tf.nn.tanh(
            dropout_layer(z_layer(Flatten()(temp_out)), dropout))

    # Generator
    with tf.variable_scope('Generator'):
        intermediate_conv_reverse = Conv2D(
            temp_temp_out.get_shape().as_list()[3], 1, padding='same')
        dec_dense = Dense(np.prod(reshape))
        generator = build_unified_decoder(config.outputWidth,
                                          config.intermediateResolutions,
                                          config.numChannels,
                                          use_batchnorm=False)

        # The same layer objects are used for both paths below, so the
        # z_enc -> x_enc and z -> x_ generators share weights.
        temp_out_z_enc = intermediate_conv_reverse(
            tf.reshape(dropout_layer(dec_dense(z_enc), dropout),
                       [-1, *reshape]))
        # encoder training:
        for layer in generator:
            temp_out_z_enc = layer(temp_out_z_enc)
        outputs['x_enc'] = x_enc = sigmoid(temp_out_z_enc)  # recon_img
        # generator training
        temp_out = intermediate_conv_reverse(
            tf.reshape(dropout_layer(dec_dense(z), dropout), [-1, *reshape]))
        for layer in generator:
            temp_out = layer(temp_out)
        outputs['x_'] = x_ = sigmoid(temp_out)

    # Discriminator
    with tf.variable_scope('Discriminator'):
        discriminator = build_unified_encoder(x_.get_shape().as_list(),
                                              config.intermediateResolutions,
                                              use_batchnorm=False)
        discriminator_dense = Dense(1)

        # fake:
        temp_out = x_
        for layer in discriminator:
            temp_out = layer(temp_out)
        outputs['d_fake_features'] = temp_out
        outputs['d_'] = discriminator_dense(temp_out)

        # real:
        temp_out = x
        for layer in discriminator:
            temp_out = layer(temp_out)
        outputs['d_features'] = temp_out  # image_features
        outputs['d'] = discriminator_dense(temp_out)

        # One uniform alpha per sample; x_hat = x + alpha * (x_ - x) lies on the
        # segment between real and generated images (presumably for a gradient
        # penalty computed by the caller). Requires the batch size to equal
        # config.batchsize.
        alpha = tf.random_uniform(shape=[config.batchsize, 1],
                                  minval=0.,
                                  maxval=1.)  # eps
        diff = tf.reshape(
            (x_ - x), [config.batchsize,
                       np.prod(x.get_shape().as_list()[1:])])
        outputs['x_hat'] = x_hat = x + tf.reshape(
            alpha * diff, [config.batchsize, *x.get_shape().as_list()[1:]])

        temp_out = x_hat
        for layer in discriminator:
            temp_out = layer(temp_out)
        outputs['d_hat_features'] = temp_out
        outputs['d_hat'] = discriminator_dense(temp_out)

        # encoder training:
        temp_out = x_enc
        for layer in discriminator:
            temp_out = layer(temp_out)
        outputs['d_enc_features'] = temp_out  # recon_features
        outputs['d_enc'] = discriminator_dense(temp_out)

    return outputs
Example #5
0
def anovaegan(x, dropout_rate, dropout, config):
    """Build a VAE-GAN graph: variational encoder, generator and discriminator.

    The encoder yields ``z_mu`` / ``z_log_sigma``. A reparameterised sample is
    decoded by an unbatchnormed generator into ``'out'``. One discriminator with
    shared weights scores the generated images, the real images ``x`` and a
    random interpolate ``x_hat`` between the two.

    Args:
        x: Real input image tensor. Index 3 of the encoder output shape is read
            as channels (NHWC assumed -- TODO confirm).
        dropout_rate: Rate of the single ``Dropout`` layer, which is shared by
            the encoder and generator.
        dropout: ``training`` flag passed to every dropout call.
        config: Provides ``intermediateResolutions``, ``zDim``,
            ``outputWidth``, ``numChannels`` and ``batchsize``.

    Returns:
        dict with ``'z_mu'``, ``'z_log_sigma'``, ``'z_sigma'``, ``'out'``,
        ``'d_fake_features'``, ``'d_'``, ``'d_features'``, ``'d'``,
        ``'x_hat'``, ``'d_hat_features'`` and ``'d_hat'``.
    """
    outputs = {}

    # Encoder
    with tf.variable_scope('Encoder'):
        encoder = build_unified_encoder(x.get_shape().as_list(), config.intermediateResolutions)

        temp_out = x
        for layer in encoder:
            temp_out = layer(temp_out)

        # Un-squeezed encoder output; its channel count sizes the generator's
        # expanding 1x1 convolution.
        temp_temp_out = temp_out
        intermediate_conv = Conv2D(temp_temp_out.get_shape().as_list()[3] // 8, 1, padding='same')
        dropout_layer = Dropout(dropout_rate)
        temp_out = intermediate_conv(temp_out)

        reshape = temp_out.get_shape().as_list()[1:]
        mu_layer = Dense(config.zDim)
        sigma_layer = Dense(config.zDim)

        flatten = Flatten()(temp_out)
        outputs['z_mu'] = z_mu = dropout_layer(mu_layer(flatten), dropout)
        outputs['z_log_sigma'] = z_log_sigma = dropout_layer(sigma_layer(flatten), dropout)
        outputs['z_sigma'] = z_sigma = tf.exp(z_log_sigma)
        # Reparameterisation: z = mu + eps * sigma with eps ~ N(0, I).
        z_vae = z_mu + tf.random_normal(tf.shape(z_sigma)) * z_sigma

    with tf.variable_scope("Generator"):
        intermediate_conv_reverse = Conv2D(temp_temp_out.get_shape().as_list()[3], 1, padding='same')
        dec_dense = Dense(np.prod(reshape))
        decoder = build_unified_decoder(outputWidth=config.outputWidth, intermediateResolutions=config.intermediateResolutions,
                                        outputChannels=config.numChannels,
                                        use_batchnorm=False)

        # Pass the `dropout` training flag like every other dropout call here.
        # Without it Keras uses the global learning phase instead of the
        # caller's switch.
        reshaped = tf.reshape(dropout_layer(dec_dense(z_vae), dropout), [-1, *reshape])
        temp_out = intermediate_conv_reverse(reshaped)

        # Decode: z -> x_hat
        for layer in decoder:
            temp_out = layer(temp_out)

        outputs['out'] = temp_out

    # Discriminator
    with tf.variable_scope('Discriminator'):
        discriminator = build_unified_encoder(temp_out.get_shape().as_list(), intermediateResolutions=config.intermediateResolutions, use_batchnorm=False)
        discriminator_dense = Dense(1)

        # fake/reconstructed: temp_out still holds the generator output here.
        for layer in discriminator:
            temp_out = layer(temp_out)
        outputs['d_fake_features'] = temp_out
        outputs['d_'] = discriminator_dense(temp_out)

        # real:
        temp_out = x
        for layer in discriminator:
            temp_out = layer(temp_out)
        outputs['d_features'] = temp_out  # image_features
        outputs['d'] = discriminator_dense(temp_out)

        # for GP: one uniform alpha per sample, x_hat on the segment between
        # x and the generated image. Requires batch size == config.batchsize.
        alpha = tf.random_uniform(shape=[config.batchsize, 1], minval=0., maxval=1.)  # eps
        diff = tf.reshape((outputs['out'] - x), [config.batchsize, np.prod(x.get_shape().as_list()[1:])])
        outputs['x_hat'] = x_hat = x + tf.reshape(alpha * diff, [config.batchsize, *x.get_shape().as_list()[1:]])

        temp_out = x_hat
        for layer in discriminator:
            temp_out = layer(temp_out)
        outputs['d_hat_features'] = temp_out
        outputs['d_hat'] = discriminator_dense(temp_out)
    return outputs
Example #6
0
def fanogan_schlegl(z, x, dropout_rate, dropout, config):
    """Build an f-AnoGAN graph with residual generator/discriminator networks.

    An encoder maps ``x`` to a tanh-bounded code ``z_enc``. A residual generator
    (dict of layers run by ``evaluate_generator``) is applied with shared
    weights to the caller-supplied ``z`` (-> ``x_``) and to ``z_enc``
    (-> ``x_enc``). A residual discriminator (run by ``evaluate_discriminator``)
    is applied with shared weights to ``x_``, ``x``, a random interpolate
    ``x_hat`` and ``x_enc``.

    Args:
        z: Latent codes supplied by the caller for the generator path.
            Presumably shape ``(batch, config.zDim)`` -- confirm with caller.
        x: Real input image tensor.
        dropout_rate: Not used by this builder; kept for a signature matching
            the other model builders.
        dropout: Not used by this builder either.
        config: Provides ``intermediateResolutions``, ``zDim`` and
            ``batchsize``.

    Returns:
        dict with keys ``'z_enc'``, ``'x_'``, ``'x_enc'``,
        ``'d_fake_features'``, ``'d_'``, ``'d_features'``, ``'d'``,
        ``'x_hat'``, ``'d_hat_features'``, ``'d_hat'``, ``'d_enc_features'``
        and ``'d_enc'``.
    """
    outputs = {}
    # Base filter width; generator/discriminator layers use multiples of it.
    dim = 64
    # Encoder
    with tf.variable_scope('Encoder'):
        encoder = build_unified_encoder(x.get_shape().as_list(), intermediateResolutions=config.intermediateResolutions)
        enc_dense = Dense(config.zDim)

        temp_out = x
        for layer in encoder:
            temp_out = layer(temp_out)
        outputs['z_enc'] = z_enc = tf.nn.tanh(enc_dense(Flatten()(temp_out)))  # restricting encoder outputs to range [-1;1]

    # Generator
    with tf.variable_scope('Generator'):
        # NOTE: entries are constructed in dict order, which fixes the Keras
        # auto-generated layer/variable names -- keep the order stable for
        # checkpoint compatibility.
        generator = Bunch({
            # Model definition
            'gen_1': Dense(np.prod(config.intermediateResolutions) * 8 * dim),

            'gen_res1_conv1': Conv2D(filters=8 * dim, kernel_size=3, padding='same'),
            'gen_res1_layernorm1': LayerNormalization([1, 2]),
            'gen_res1_conv2': Conv2DTranspose(filters=8 * dim, kernel_size=3, padding='same'),
            'gen_res1_layernorm2': LayerNormalization([1, 2]),

            'gen_res2_conv1': Conv2D(filters=4 * dim, kernel_size=3, padding='same'),
            'gen_res2_layernorm1': LayerNormalization([1, 2]),
            'gen_res2_conv2': Conv2DTranspose(filters=4 * dim, kernel_size=3, strides=2, padding='same'),
            'gen_res2_layernorm2': LayerNormalization([1, 2]),
            'gen_res2_shortcut': Conv2DTranspose(filters=4 * dim, kernel_size=1, padding='same', strides=2),

            'gen_res3_conv1': Conv2D(filters=2 * dim, kernel_size=3, padding='same'),
            'gen_res3_layernorm1': LayerNormalization([1, 2]),
            'gen_res3_conv2': Conv2DTranspose(filters=2 * dim, kernel_size=3, strides=2, padding='same'),
            'gen_res3_layernorm2': LayerNormalization([1, 2]),
            'gen_res3_shortcut': Conv2DTranspose(filters=2 * dim, kernel_size=1, padding='same', strides=2),

            'gen_res4_conv1': Conv2D(filters=1 * dim, kernel_size=3, padding='same'),
            'gen_res4_layernorm1': LayerNormalization([1, 2]),
            'gen_res4_conv2': Conv2DTranspose(filters=1 * dim, kernel_size=3, strides=2, padding='same'),
            'gen_res4_layernorm2': LayerNormalization([1, 2]),
            'gen_res4_shortcut': Conv2DTranspose(filters=1 * dim, kernel_size=1, padding='same', strides=2),

            # post process
            'gen_layernorm': LayerNormalization([1, 2]),
            'gen_conv': Conv2D(1, 1, padding='same', activation='tanh')
        })

        outputs['x_'] = x_ = evaluate_generator(generator, z, config.intermediateResolutions, dim)

        # encoder training: same generator layers, so weights are shared.
        outputs['x_enc'] = x_enc = evaluate_generator(generator, z_enc, config.intermediateResolutions, dim)

    # Discriminator
    with tf.variable_scope('Discriminator'):
        discriminator = Bunch({
            # Model definition
            'dis_conv': Conv2D(dim, 3, padding='same'),

            'dis_res1_conv1': Conv2D(filters=2 * dim, kernel_size=3, padding='same'),
            'dis_res1_layernorm1': LayerNormalization([1, 2]),
            'dis_res1_conv2': Conv2D(filters=2 * dim, kernel_size=3, strides=2, padding='same'),
            'dis_res1_layernorm2': LayerNormalization([1, 2]),
            'dis_res1_shortcut1': Conv2D(filters=2 * dim, kernel_size=1, padding='same'),
            'dis_res1_shortcut2': AvgPool2D(),

            'dis_res2_conv1': Conv2D(filters=4 * dim, kernel_size=3, padding='same'),
            'dis_res2_layernorm1': LayerNormalization([1, 2]),
            'dis_res2_conv2': Conv2D(filters=4 * dim, kernel_size=3, strides=2, padding='same'),
            'dis_res2_layernorm2': LayerNormalization([1, 2]),
            'dis_res2_shortcut1': Conv2D(filters=4 * dim, kernel_size=1, padding='same'),
            'dis_res2_shortcut2': AvgPool2D(),

            'dis_res3_conv1': Conv2D(filters=8 * dim, kernel_size=3, padding='same'),
            'dis_res3_layernorm1': LayerNormalization([1, 2]),
            'dis_res3_conv2': Conv2D(filters=8 * dim, kernel_size=3, strides=2, padding='same'),
            'dis_res3_layernorm2': LayerNormalization([1, 2]),
            'dis_res3_shortcut1': Conv2D(filters=8 * dim, kernel_size=1, padding='same'),
            'dis_res3_shortcut2': AvgPool2D(),

            'dis_res4_conv1': Conv2D(filters=8 * dim, kernel_size=3, padding='same'),
            'dis_res4_layernorm1': LayerNormalization([1, 2]),
            'dis_res4_conv2': Conv2D(filters=8 * dim, kernel_size=3, padding='same'),
            'dis_res4_layernorm2': LayerNormalization([1, 2]),

            # post process
            # 'dis_flatten': Flatten(),
            'dis_dense': Dense(1),
        })

        # fake:
        outputs['d_fake_features'], outputs['d_'] = evaluate_discriminator(discriminator, x_)

        # real:
        outputs['d_features'], outputs['d'] = evaluate_discriminator(discriminator, x)

        # add noise: one uniform alpha per sample, x_hat = x + alpha * (x_ - x)
        # on the segment between real and generated images (presumably for a
        # gradient penalty computed by the caller). Requires batch size ==
        # config.batchsize.
        alpha = tf.random_uniform(shape=[config.batchsize, 1], minval=0., maxval=1.)  # eps
        diff = tf.reshape((x_ - x), [config.batchsize, np.prod(x.get_shape().as_list()[1:])])
        outputs['x_hat'] = x_hat = x + tf.reshape(alpha * diff, [config.batchsize, *x.get_shape().as_list()[1:]])

        outputs['d_hat_features'], outputs['d_hat'] = evaluate_discriminator(discriminator, x_hat)

        # encoder training:
        outputs['d_enc_features'], outputs['d_enc'] = evaluate_discriminator(discriminator, x_enc)

    return outputs
def gaussian_mixture_variational_autoencoder_spatial(x, dropout_rate, dropout,
                                                     config):
    """Build a spatial Gaussian-mixture VAE graph (per-pixel latents).

    1x1 convolutions on the encoder feature map give per-location posteriors
    q(w|x) and q(z|x), each sampled with the reparameterisation trick. From the
    sampled ``w``, the per-cluster prior parameters p(z|w,c) are predicted for
    ``config.dim_c`` clusters. The sampled ``z`` is decoded into ``xz_mu``, and
    a per-location cluster posterior ``pc`` is computed from how well ``z`` fits
    each cluster.

    Args:
        x: Input image tensor; its static shape configures the encoder.
        dropout_rate: Not used by this builder; kept for a signature matching
            the other model builders.
        dropout: Not used by this builder either.
        config: Provides ``intermediateResolutions``, ``dim_w``, ``dim_z``,
            ``dim_c``, ``outputWidth`` and ``numChannels``.

    Returns:
        dict with ``'w_mu'``, ``'w_log_sigma'``, ``'w_sampled'``, ``'z_mu'``,
        ``'z_log_sigma'``, ``'z_sampled'``, ``'z_wc_mus'``,
        ``'z_wc_log_sigma_invs'``, ``'z_wc_sampled'``, ``'xz_mu'``,
        ``'pc_logit'`` and ``'pc'``.
    """
    outputs = {}

    # encoding network q(z|x) and q(w|x)
    encoder = build_unified_encoder(x.get_shape().as_list(),
                                    config.intermediateResolutions)

    w_mu_layer = Conv2D(filters=config.dim_w,
                        kernel_size=1,
                        strides=1,
                        padding='same',
                        name='q_wz_x/w_mu')
    w_log_sigma_layer = Conv2D(filters=config.dim_w,
                               kernel_size=1,
                               strides=1,
                               padding='same',
                               name='q_wz_x/w_log_sigma')

    z_mu_layer = Conv2D(filters=config.dim_z,
                        kernel_size=1,
                        strides=1,
                        padding='same',
                        name='q_wz_x/z_mu')
    z_log_sigma_layer = Conv2D(filters=config.dim_z,
                               kernel_size=1,
                               strides=1,
                               padding='same',
                               name='q_wz_x/z_log_sigma')

    temp_out = x
    for layer in encoder:
        temp_out = layer(temp_out)

    outputs['w_mu'] = w_mu = w_mu_layer(temp_out)
    outputs['w_log_sigma'] = w_log_sigma = w_log_sigma_layer(temp_out)
    # reparametrization (scale is exp(0.5 * log_sigma))
    outputs['w_sampled'] = w_sampled = w_mu + tf.random_normal(
        tf.shape(w_log_sigma)) * tf.exp(0.5 * w_log_sigma)

    outputs['z_mu'] = z_mu = z_mu_layer(temp_out)
    outputs['z_log_sigma'] = z_log_sigma = z_log_sigma_layer(temp_out)
    # reparametrization
    outputs['z_sampled'] = z_sampled = z_mu + tf.random_normal(
        tf.shape(z_log_sigma)) * tf.exp(0.5 * z_log_sigma)

    # posterior p(z|w,c)
    conv_7 = Conv2D(filters=64,
                    kernel_size=1,
                    strides=1,
                    padding='same',
                    name='p_z_wc/1x1convlayer',
                    activation=relu)
    z_wc_mu_layer = Conv2D(filters=config.dim_z * config.dim_c,
                           kernel_size=1,
                           strides=1,
                           padding='same',
                           name='p_z_wc/z_wc_mu')
    z_wc_log_sigma_layer = Conv2D(filters=config.dim_z * config.dim_c,
                                  kernel_size=1,
                                  strides=1,
                                  padding='same',
                                  name='p_z_wc/z_wc_log_sigma')

    mid = conv_7(w_sampled)
    z_wc_mu = z_wc_mu_layer(mid)
    z_wc_log_sigma = z_wc_log_sigma_layer(mid)
    # Extra trainable bias (initialised to 0.1) created directly as a tf.Variable.
    z_wc_log_sigma_inv = tf.nn.bias_add(
        z_wc_log_sigma,
        bias=tf.Variable(
            tf.constant(0.1,
                        shape=[z_wc_log_sigma.get_shape()[-1]],
                        dtype=tf.float32)))
    # Split the channel axis into (dim_z, dim_c): one mean per latent dim per
    # cluster at every spatial location.
    outputs['z_wc_mus'] = z_wc_mus = tf.reshape(z_wc_mu, [
        -1,
        z_wc_mu.get_shape().as_list()[1],
        z_wc_mu.get_shape().as_list()[2], config.dim_z, config.dim_c
    ])
    z_wc_sigma_shape = z_wc_log_sigma_inv.get_shape().as_list()
    outputs['z_wc_log_sigma_invs'] = z_wc_log_sigma_invs = tf.reshape(
        z_wc_log_sigma_inv, [
            -1, z_wc_sigma_shape[1], z_wc_sigma_shape[2], config.dim_z,
            config.dim_c
        ])
    # reparametrization
    outputs['z_wc_sampled'] = z_wc_mus + tf.random_normal(
        tf.shape(z_wc_log_sigma_invs)) * tf.exp(z_wc_log_sigma_invs)

    # decoder p(x|z): decode the sampled latent. Previously the decoder consumed
    # the raw encoder feature map, so z_sampled never reached the reconstruction.
    decoder = build_unified_decoder(config.outputWidth,
                                    config.intermediateResolutions,
                                    config.numChannels)
    temp_out = z_sampled
    for layer in decoder:
        temp_out = layer(temp_out)

    outputs['xz_mu'] = temp_out

    # prior p(c): compare z against every cluster by tiling it along a new
    # trailing cluster axis.
    z_sample = tf.tile(tf.expand_dims(z_sampled, -1),
                       [1, 1, 1, 1, config.dim_c])
    # NOTE(review): this is not the textbook Gaussian log-density (scale and
    # constant terms differ); kept as-is -- confirm intent.
    loglh = -0.5 * (tf.squared_difference(z_sample, z_wc_mus) * tf.exp(
        z_wc_log_sigma_invs)) - z_wc_log_sigma_invs + tf.log(np.pi)
    # Sum over the dim_z axis -> per-location, per-cluster scores.
    loglh_sum = tf.reduce_sum(loglh, 3)
    outputs['pc_logit'] = loglh_sum
    outputs['pc'] = tf.nn.softmax(loglh_sum)

    return outputs
Example #8
0
def adversarial_autoencoder(z, x, dropout_rate, dropout, config):
    """Build an adversarial autoencoder graph with a latent-space discriminator.

    An encoder/decoder pair with a dense bottleneck produces the code ``z_`` and
    the reconstruction ``x_hat``. A small MLP discriminator with shared weights
    scores the encoded codes (``'d_'``), the caller-supplied codes ``z``
    (``'d'``) and random interpolates ``z_hat`` between the two (``'d_hat'``).

    Args:
        z: Latent codes supplied by the caller and fed to the discriminator as
            "real" samples. Presumably drawn from the target prior, shape
            ``(config.batchsize, config.zDim)`` -- confirm with caller.
        x: Input image tensor. Index 3 of the encoder output shape is read as
            channels (NHWC assumed -- TODO confirm).
        dropout_rate: Rate of the shared ``Dropout`` layer.
        dropout: ``training`` flag passed to every dropout call.
        config: Provides ``intermediateResolutions``, ``zDim``,
            ``outputWidth``, ``numChannels`` and ``batchsize``.

    Returns:
        dict with ``'z_'``, ``'x_hat'``, ``'d_'``, ``'d'``, ``'z_hat'`` and
        ``'d_hat'``.
    """
    outputs = {}

    with tf.variable_scope('Encoder'):
        encoder = build_unified_encoder(x.get_shape().as_list(),
                                        config.intermediateResolutions)

        temp_out = x
        for layer in encoder:
            temp_out = layer(temp_out)

    with tf.variable_scope("Bottleneck"):
        intermediate_conv = Conv2D(temp_out.get_shape().as_list()[3] // 8,
                                   1,
                                   padding='same')
        intermediate_conv_reverse = Conv2D(temp_out.get_shape().as_list()[3],
                                           1,
                                           padding='same')
        dropout_layer = Dropout(dropout_rate)
        temp_out = intermediate_conv(temp_out)

        reshape = temp_out.get_shape().as_list()[1:]
        z_layer = Dense(config.zDim)
        dec_dense = Dense(np.prod(reshape))

        outputs['z_'] = z_ = dropout_layer(z_layer(Flatten()(temp_out)),
                                           dropout)
        reshaped = tf.reshape(dropout_layer(dec_dense(z_), dropout),
                              [-1, *reshape])
        temp_out = intermediate_conv_reverse(reshaped)

    with tf.variable_scope('Decoder'):
        decoder = build_unified_decoder(config.outputWidth,
                                        config.intermediateResolutions,
                                        config.numChannels)

        # Decode: z -> x_hat
        for layer in decoder:
            temp_out = layer(temp_out)

        outputs['x_hat'] = temp_out

    # Discriminator
    with tf.variable_scope('Discriminator'):
        discriminator = [
            Dense(50, activation=leaky_relu),
            Dense(50, activation=leaky_relu),
            Dense(1)
        ]

        # fake
        temp_out = z_
        for layer in discriminator:
            temp_out = layer(temp_out)
        outputs['d_'] = temp_out

        # real
        temp_out = z
        for layer in discriminator:
            temp_out = layer(temp_out)
        outputs['d'] = temp_out

        # adding noise: sample points on the segment between the real codes z
        # and the encoded codes z_ (epsilon in [0, 1) per sample), presumably
        # for a gradient penalty computed by the caller. The previous form
        # `z + epsilon * (z - z_)` extrapolated away from z_ instead.
        epsilon = tf.random_uniform([config.batchsize, 1],
                                    minval=0.,
                                    maxval=1.)
        outputs['z_hat'] = z_hat = z + epsilon * (z_ - z)

        temp_out = z_hat
        for layer in discriminator:
            temp_out = layer(temp_out)
        outputs['d_hat'] = temp_out

    return outputs
Example #9
0
def gaussian_mixture_variational_autoencoder(x, dropout_rate, dropout, config):
    """Build a Gaussian-mixture VAE graph with dense latents.

    The squeezed, flattened encoder features give posteriors q(w|x) and q(z|x),
    each sampled with the reparameterisation trick. From the sampled ``w``, the
    per-cluster prior parameters p(z|w,c) are predicted for ``config.dim_c``
    clusters. The sampled ``z`` is decoded into ``xz_mu``, and the cluster
    posterior ``pc`` is computed from how well ``z`` fits each cluster.

    Args:
        x: Input image tensor. Index 3 of the encoder output shape is read as
            channels (NHWC assumed -- TODO confirm).
        dropout_rate: Rate of the shared ``Dropout`` layer.
        dropout: ``training`` flag passed to every dropout call.
        config: Provides ``intermediateResolutions``, ``dim_w``, ``dim_z``,
            ``dim_c``, ``outputWidth`` and ``numChannels``.

    Returns:
        dict with ``'w_mu'``, ``'w_log_sigma'``, ``'w_sampled'``, ``'z_mu'``,
        ``'z_log_sigma'``, ``'z_sampled'``, ``'z_wc_mus'``,
        ``'z_wc_log_sigma_invs'``, ``'z_wc_sampled'``, ``'xz_mu'``,
        ``'pc_logit'`` and ``'pc'``.
    """
    layers = {}
    # encoding network q(z|x) and q(w|x)
    with tf.variable_scope('Encoder'):
        encoder = build_unified_encoder(x.get_shape().as_list(), config.intermediateResolutions)

        temp_out = x
        for layer in encoder:
            temp_out = layer(temp_out)

    with tf.variable_scope("Bottleneck"):
        intermediate_conv = Conv2D(temp_out.get_shape().as_list()[3] // 8, 1, padding='same')
        intermediate_conv_reverse = Conv2D(temp_out.get_shape().as_list()[3], 1, padding='same')
        dropout_layer = Dropout(dropout_rate)
        temp_out = intermediate_conv(temp_out)
        reshape = temp_out.get_shape().as_list()[1:]

        w_mu_layer = Dense(config.dim_w)
        w_log_sigma_layer = Dense(config.dim_w)

        z_mu_layer = Dense(config.dim_z)
        z_log_sigma_layer = Dense(config.dim_z)
        dec_dense = Dense(np.prod(reshape))

        flatten = Flatten()(temp_out)

        layers['w_mu'] = w_mu = dropout_layer(w_mu_layer(flatten), dropout)
        layers['w_log_sigma'] = w_log_sigma = dropout_layer(w_log_sigma_layer(flatten), dropout)
        # reparametrization (scale is exp(0.5 * log_sigma))
        layers['w_sampled'] = w_sampled = w_mu + tf.random_normal(tf.shape(w_log_sigma)) * tf.exp(0.5 * w_log_sigma)

        layers['z_mu'] = z_mu = dropout_layer(z_mu_layer(flatten), dropout)
        # Pass the `dropout` training flag like the sibling calls. Without it
        # Keras uses the global learning phase instead of the caller's switch.
        layers['z_log_sigma'] = z_log_sigma = dropout_layer(z_log_sigma_layer(flatten), dropout)
        layers['z_sampled'] = z_sampled = z_mu + tf.random_normal(tf.shape(z_log_sigma)) * tf.exp(0.5 * z_log_sigma)

        temp_out = intermediate_conv_reverse(tf.reshape(dropout_layer(dec_dense(z_sampled), dropout), [-1, *reshape]))

    # posterior p(z|w,c)
    z_wc_mu_layer = Dense(config.dim_z * config.dim_c)
    z_wc_log_sigma_layer = Dense(config.dim_z * config.dim_c)

    z_wc_mu = z_wc_mu_layer(w_sampled)
    z_wc_log_sigma = z_wc_log_sigma_layer(w_sampled)
    # Extra trainable bias (initialised to 0.1) created directly as a tf.Variable.
    z_wc_log_sigma_inv = tf.nn.bias_add(z_wc_log_sigma, bias=tf.Variable(tf.constant(0.1, shape=[z_wc_log_sigma.get_shape()[-1]], dtype=tf.float32)))
    # Split the feature axis into (dim_z, dim_c): one mean per latent dim per cluster.
    layers['z_wc_mus'] = z_wc_mus = tf.reshape(z_wc_mu, [-1, config.dim_z, config.dim_c])
    layers['z_wc_log_sigma_invs'] = z_wc_log_sigma_invs = tf.reshape(z_wc_log_sigma_inv, [-1, config.dim_z, config.dim_c])
    # reparametrization
    layers['z_wc_sampled'] = z_wc_mus + tf.random_normal(tf.shape(z_wc_log_sigma_invs)) * tf.exp(z_wc_log_sigma_invs)

    # decoder p(x|z)
    with tf.variable_scope('Decoder'):
        decoder = build_unified_decoder(config.outputWidth, config.intermediateResolutions, config.numChannels)

        for layer in decoder:
            temp_out = layer(temp_out)

        layers['xz_mu'] = temp_out

    # prior p(c): compare z against every cluster by tiling it along a new
    # trailing cluster axis.
    z_sample = tf.tile(tf.expand_dims(z_sampled, -1), [1, 1, config.dim_c])
    # NOTE(review): this is not the textbook Gaussian log-density (scale and
    # constant terms differ); kept as-is -- confirm intent.
    loglh = -0.5 * (tf.squared_difference(z_sample, z_wc_mus) * tf.exp(z_wc_log_sigma_invs)) - z_wc_log_sigma_invs + tf.log(np.pi)
    # Sum over the dim_z axis -> one score per cluster.
    loglh_sum = tf.reduce_sum(loglh, 1)
    layers['pc_logit'] = loglh_sum
    layers['pc'] = tf.nn.softmax(loglh_sum)

    return layers