def decoder_bloc(u, z_corr, mean, var, layer_spec=None, epsilon=1e-10):
    """Run one decoder bloc: denoise, normalize, and optionally project down.

    Combines the corrupted encoder tensor ``z_corr`` with the top-down signal
    ``u`` through ``g_gauss`` to get a denoised estimate ``z_est``, normalizes
    that estimate with ``mean``/``var``, and, when ``layer_spec`` is given,
    runs ``z_est`` through the transposed layer followed by batch
    normalization to produce the top-down signal for the next bloc.

    Args:
        u: Top-down input signal for this bloc.
        z_corr: The ``z`` tensor of the matching bloc from the corrupted
            encoder pass.
        mean: Mean used to normalize the estimate (presumably the clean
            encoder's batch mean -- TODO confirm against callers).
        var: Variance used to normalize the estimate (presumably the clean
            encoder's batch variance -- TODO confirm against callers).
        layer_spec: Spec of the corresponding layer. When ``None`` no
            transposed layer is run and ``u`` is returned unchanged.
        epsilon: Constant added to ``var`` before the square root to avoid
            division by zero. Defaults to the previously hard-coded 1e-10.

    Returns:
        Tuple ``(u, z_est_BN)``: the new top-down signal (or the input ``u``
        when ``layer_spec`` is None) and the normalized estimate, tagged via
        ``tf.identity`` with the name ``"z_est_BN"`` inside the current scope.
    """
    # Denoising: combine the corrupted activation with the top-down signal.
    z_est = g_gauss(z_corr, u)

    # A plain Python float lets TensorFlow convert epsilon to var's dtype;
    # tf.constant(1e-10) is always float32 and cannot be added to float64.
    z_est_BN = (z_est - mean) / tf.sqrt(var + epsilon)
    # Give the tensor a stable name so it can be fetched from the graph later.
    z_est_BN = tf.identity(z_est_BN, name="z_est_BN")

    if layer_spec is not None:
        # The un-normalized estimate (not z_est_BN) feeds the transposed layer.
        u = run_transpose_layer(z_est, layer_spec)
        u = batch_normalization(u, output_name="u")

    return u, z_est_BN
Exemplo n.º 2
0
def decoder_bloc(u, z_corr, mean, var, layer_spec=None, epsilon=1e-10):
    """Perform the decoding operations of a corresponding encoder bloc.

    Denoises the corrupted encoder tensor ``z_corr`` using the top-down signal
    ``u`` (via ``g_gauss``), normalizes the estimate with ``mean``/``var``,
    and, when ``layer_spec`` is given, runs the estimate through the
    transposed layer plus batch normalization to get the next top-down signal.

    Args:
        u: Top-down input signal for this bloc.
        z_corr: The ``z`` tensor of the matching bloc from the corrupted
            encoder pass.
        mean: Mean used to normalize the estimate (the visible caller passes
            the clean encoder's ``mean`` tensor).
        var: Variance used to normalize the estimate (the visible caller
            passes the clean encoder's ``var`` tensor).
        layer_spec: Spec of the corresponding layer. When ``None`` no
            transposed layer is run and ``u`` is returned unchanged.
        epsilon: Constant added to ``var`` before the square root to avoid
            division by zero. Defaults to the previously hard-coded 1e-10.

    Returns:
        Tuple ``(u, z_est_BN)``: the new top-down signal (or the input ``u``
        when ``layer_spec`` is None) and the normalized estimate, tagged via
        ``tf.identity`` with the name ``"z_est_BN"`` inside the current scope.
    """
    # Denoising
    z_est = g_gauss(z_corr, u)

    # Normalize the estimate. epsilon is a Python float so TensorFlow casts it
    # to var's dtype (tf.constant(1e-10) would be float32 and break float64).
    z_est_BN = (z_est - mean) / tf.sqrt(var + epsilon)
    # Stable name so the tensor can be looked up in the graph later.
    z_est_BN = tf.identity(z_est_BN, name="z_est_BN")

    # Run the transposed layer on the un-normalized estimate.
    if layer_spec is not None:
        u = run_transpose_layer(z_est, layer_spec)
        u = batch_normalization(u, output_name="u")

    return u, z_est_BN
Exemplo n.º 3
0
def get_tensor(input_name, num_encoder_bloc, name_tensor):
    """Fetch a tensor created inside an encoder bloc from the default graph.

    Looks up ``"<input_name>/encoder_bloc_<num_encoder_bloc>/<name_tensor>:0"``,
    i.e. output 0 of the op ``name_tensor`` inside the ``encoder_bloc_<n>``
    scope of the encoder pass named ``input_name`` (e.g. ``"AE_clean"``).

    Args:
        input_name: Name scope of the encoder pass.
        num_encoder_bloc: Index of the encoder bloc.
        name_tensor: Name of the op whose first output is wanted.

    Returns:
        The matching ``tf.Tensor``. Errors raised by
        ``Graph.get_tensor_by_name`` (e.g. for a missing tensor) propagate.
    """
    tensor_name = "{}/encoder_bloc_{}/{}:0".format(
        input_name, num_encoder_bloc, name_tensor)
    return tf.get_default_graph().get_tensor_by_name(tensor_name)

d_cost = []
u = batch_normalization(AE_y_corr, output_name="u_L")
for i in range(L, 0, -1):
    layer_spec = layers[i - 1]

    with tf.variable_scope("decoder_bloc_" + str(i), reuse=tf.AUTO_REUSE):
        # if the layer is max pooling or "flat", the transposed layer is run without creating a decoder bloc.
        if layer_spec["type"] in ["max_pool_2x2", "flat"]:
            h = get_tensor("AE_corrupted", i - 1, "h")
            output_shape = tf.shape(h)
            u = run_transpose_layer(u, layer_spec, output_shape=output_shape)
        else:
            z_corr, z = [
                get_tensor("AE_corrupted", i, "z"),
                get_tensor("AE_clean", i, "z")
            ]
            mean, var = [
                get_tensor("AE_clean", i, "mean"),
                get_tensor("AE_clean", i, "var")
            ]

            u, z_est_BN = decoder_bloc(u,
                                       z_corr,
                                       mean,
                                       var,
                                       layer_spec=layer_spec)