Code example #1
File: priors.py — project: templeblock/PHiSeg-code
def prob_unet2D(z_list, x, zdim_0, n_classes, generation_mode, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """Prior network of the 2D probabilistic U-Net: p(z | x).

    Runs a plain convolutional encoder over ``x``, global-average-pools the
    deepest feature map into per-channel ``mu`` and ``sigma`` heads and draws
    one reparameterised sample ``z = mu + sigma * eps``.

    Args:
        z_list, n_classes, generation_mode: Unused here; present only so this
            function shares the common prior-network call signature.
        x: Input image tensor, shape (batch, H, W, C).
        zdim_0: Dimensionality of the global latent vector.
        training: Flag/tensor forwarded to the normalisation layers.
        scope_reuse: Reuse variables of an earlier call under scope 'prior'.
        norm: Normalisation layer constructor.
        **kwargs: 'resolution_levels' (default 7) and 'n0' (default 32).

    Returns:
        (z, mu, sigma): single-element lists with the latent sample and the
        Gaussian parameters.  # shape presumably (batch, 1, 1, zdim_0) after
        # layers.global_averagepool2D — confirm against that helper.
    """

    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)

    # Channel schedule: n0, 2*n0, 4*n0, then 6*n0 for every deeper level.
    # Derived from resolution_levels so values > 7 no longer raise an
    # IndexError (the old hard-coded 7-entry list did); identical for <= 7.
    num_channels = [n0, 2 * n0, 4 * n0] + [6 * n0] * max(resolution_levels - 3, 0)

    conv_unit = layers.conv2D

    with tf.variable_scope('prior') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm supplies its own learned shift, so conv biases would be
        # redundant parameters.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []

        for ii in range(resolution_levels):

            enc.append([])

            # In first layer set input to x rather than max pooling
            if ii == 0:
                enc[ii].append(x)
            else:
                enc[ii].append(layers.averagepool2D(enc[ii-1][-1]))

            enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_1' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias))
            enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_2' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias))
            enc[ii].append(conv_unit(enc[ii][-1], 'conv_%d_3' % ii, num_filters=num_channels[ii], training=training, normalisation=norm, add_bias=add_bias))

        # 1x1 heads + global average pooling collapse the deepest feature map
        # into the Gaussian parameters; softplus keeps sigma strictly positive.
        mu_p = conv_unit(enc[-1][-1], 'pre_mu', num_filters=zdim_0, kernel_size=(1, 1), activation=tf.identity)
        mu = [layers.global_averagepool2D(mu_p)]

        sigma_p = conv_unit(enc[-1][-1], 'pre_sigma', num_filters=zdim_0, kernel_size=(1, 1), activation=tf.nn.softplus)
        sigma = [layers.global_averagepool2D(sigma_p)]

        # Reparameterisation trick: z = mu + sigma * eps with eps ~ N(0, I).
        z = [mu[0] + sigma[0] * tf.random_normal(tf.shape(mu[0]), 0, 1, dtype=tf.float32)]

        return z, mu, sigma
Code example #2
File: posteriors.py — project: DLwbm123/PHiSeg-code
def hybrid(x,
           s_oh,
           zdim_0,
           training,
           scope_reuse=False,
           norm=tfnorm.batch_norm,
           **kwargs):
    """Hierarchical posterior network q(z | x, s) ("hybrid" variant).

    Builds a convolutional pyramid over the image concatenated with the
    zero-centred one-hot segmentation, then samples ``latent_levels`` latent
    maps from coarse to fine.  The coarsest level is conditioned on pyramid
    features only; each finer level additionally receives the upsampled
    sample(s) from the level(s) below.  Per level, ``full_cov_list[i]``
    selects between a full (Cholesky-parameterised) covariance over the
    flattened spatial grid and an element-wise diagonal Gaussian.

    Args:
        x: Input image, shape (batch, H, W, C); H and W must be statically
            known (used to size the latent maps).
        s_oh: One-hot segmentation, same spatial size as x.
        zdim_0: Number of channels of each latent map.
        training: Flag forwarded to the normalisation layers.
        scope_reuse: Reuse variables of an earlier call under scope
            'posterior'.
        norm: Normalisation layer constructor.
        **kwargs: 'n0' (base channels, default 32), 'latent_levels' (5),
            'resolution_levels' (7), 'full_cov_list' (per-level bools,
            defaults to all-diagonal), 'full_latent_dependencies' (False).

    Returns:
        (z, mu, sigma): lists of length latent_levels, finest level first.
        For diagonal levels sigma[i] is a stddev map; for full-covariance
        levels it is the Cholesky factor with positive diagonal.
    """

    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('posterior') as scope:

        if scope_reuse:
            scope.reuse_variables()

        latent_levels = kwargs.get('latent_levels', 5)
        resolution_levels = kwargs.get('resolution_levels', 7)

        # Per-level covariance choice; default to diagonal everywhere.
        # Fixed: previously a missing 'full_cov_list' kwarg crashed with a
        # TypeError when None was indexed below.
        full_cov_list = kwargs.get('full_cov_list', None)
        if full_cov_list is None:
            full_cov_list = [False] * latent_levels

        spatial_xdim = x.get_shape().as_list()[1:3]

        full_latent_dependencies = kwargs.get('full_latent_dependencies',
                                              False)

        pre_z = [None] * resolution_levels

        mu = [None] * latent_levels
        sigma = [None] * latent_levels
        z = [None] * latent_levels

        z_ups_mat = []
        for i in range(latent_levels):
            z_ups_mat.append(
                [None] *
                latent_levels)  # encoding [original resolution][upsampled to]

        # Generate pre_z's: plain conv pyramid over (x, s_oh - 0.5).
        for i in range(resolution_levels):

            if i == 0:
                # Centre the one-hot labels around zero before concatenating.
                net = tf.concat([x, s_oh - 0.5], axis=-1)
            else:
                net = layers.averagepool2D(pre_z[i - 1])

            net = layers.conv2D(net,
                                'z%d_pre_1' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_2' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_3' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)

            pre_z[i] = net

        # Generate z's, from the coarsest latent level downwards.
        for i in reversed(range(latent_levels)):

            # Latent level i lives at 1/2**(i + resolution_levels -
            # latent_levels) of the input resolution.
            spatial_zdim = [
                d // 2**(i + resolution_levels - latent_levels)
                for d in spatial_xdim
            ]
            spatial_cov_dim = spatial_zdim[0] * spatial_zdim[1]

            if i == latent_levels - 1:

                # Coarsest level: conditioned on pyramid features only.
                # NOTE(review): mu uses the default kernel size while sigma
                # below uses (1, 1) — possibly unintended asymmetry; confirm.
                mu[i] = layers.conv2D(pre_z[i + resolution_levels -
                                            latent_levels],
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)

                if full_cov_list[i]:

                    # Predict the m*(m+1)/2 entries of a lower-triangular
                    # Cholesky factor per latent channel (m = spatial_cov_dim).
                    l = layers.dense_layer(
                        pre_z[i + resolution_levels - latent_levels],
                        'z%d_sigma' % i,
                        hidden_units=zdim_0 * spatial_cov_dim *
                        (spatial_cov_dim + 1) // 2,
                        activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim *
                        (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp, tf.nn.softplus(tf.linalg.diag_part(Lp))
                    )  # Cholesky factors must have positive diagonal

                    sigma[i] = L

                    # Correlated sample: z = mu + L @ eps with eps ~ N(0, I),
                    # applied per channel over the flattened spatial grid.
                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))

                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp,
                        [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])

                    z[i] = mu[i] + eps_tmp

                else:

                    # Diagonal Gaussian: per-pixel stddev via softplus.
                    sigma[i] = layers.conv2D(pre_z[i + resolution_levels -
                                                   latent_levels],
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus,
                                             kernel_size=(1, 1))
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            else:

                # Upsample the sample(s) from below to this level's
                # resolution so they can condition the current latent.
                for j in reversed(range(0, i + 1)):

                    z_below_ups = layers.bilinear_upsample2D(
                        z_ups_mat[j + 1][i + 1], factor=2, name='ups')
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_1' %
                                                ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_2' %
                                                ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)

                    z_ups_mat[j][i + 1] = z_below_ups

                # Condition on either all coarser samples or just the one
                # immediately below.
                if full_latent_dependencies:
                    z_input = tf.concat(
                        [pre_z[i + resolution_levels - latent_levels]] +
                        z_ups_mat[i][(i + 1):latent_levels],
                        axis=3,
                        name='concat_%d' % i)
                else:
                    z_input = tf.concat([
                        pre_z[i + resolution_levels - latent_levels],
                        z_ups_mat[i][i + 1]
                    ],
                                        axis=3,
                                        name='concat_%d' % i)

                z_input = layers.conv2D(z_input,
                                        'z%d_input_1' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm,
                                        training=training)
                z_input = layers.conv2D(z_input,
                                        'z%d_input_2' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm,
                                        training=training)

                mu[i] = layers.conv2D(z_input,
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity,
                                      kernel_size=(1, 1))

                if full_cov_list[i]:

                    # Same full-covariance parameterisation as at the
                    # coarsest level, now conditioned on z_input.
                    l = layers.dense_layer(z_input,
                                           'z%d_sigma' % i,
                                           hidden_units=zdim_0 *
                                           spatial_cov_dim *
                                           (spatial_cov_dim + 1) // 2,
                                           activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim *
                        (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp, tf.nn.softplus(tf.linalg.diag_part(Lp)))

                    sigma[i] = L

                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))

                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp,
                        [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])

                    z[i] = mu[i] + eps_tmp

                else:

                    sigma[i] = layers.conv2D(z_input,
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus,
                                             kernel_size=(1, 1))
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            # Seed of the upsampling matrix: level i's own-resolution sample.
            z_ups_mat[i][i] = z[i]

    return z, mu, sigma
Code example #3
File: likelihoods.py — project: Tt199919/PHiSeg-code
def proposed(z_list,
             training,
             image_size,
             n_classes,
             scope_reuse=False,
             norm=tfnorm.batch_norm,
             rank=10,
             diagonal=False,
             **kwargs):
    """Likelihood with a low-rank or diagonal multivariate normal head.

    A U-Net over the input image (passed via kwargs as 'x') produces
    per-pixel features; three 1x1 heads parameterise the mean, the log of
    the diagonal covariance and (unless ``diagonal``) a rank-``rank``
    covariance factor of a Gaussian over the flattened logit map.

    Returns ``[[sample], distribution]`` where the sample has shape
    (batch,) + image_size[:-1] + (n_classes,).
    """
    x = kwargs.get('x')

    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    conv_unit = layers.conv2D
    deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # With batch norm the conv bias is redundant.
        add_bias = not (norm == tfnorm.batch_norm)

        down_path = []

        with tf.variable_scope('encoder'):

            for lvl in range(resolution_levels):

                # Level 0 reads the image; deeper levels pool the previous
                # level's last output.
                first = x if lvl == 0 else layers.averagepool2D(down_path[lvl - 1][-1])
                stage = [first]

                for part in (1, 2, 3):
                    stage.append(
                        conv_unit(stage[-1],
                                  'conv_%d_%d' % (lvl, part),
                                  num_filters=num_channels[lvl],
                                  training=training,
                                  normalisation=norm,
                                  add_bias=add_bias))

                down_path.append(stage)

        up_path = []

        with tf.variable_scope('decoder'):

            for step in range(resolution_levels - 1):

                src = resolution_levels - step - 1  # encoder level being mirrored

                carried = down_path[src][-1] if step == 0 else up_path[step - 1][-1]

                stage = [deconv_unit(carried)]

                # Skip connection from the encoder one level shallower.
                stage.append(
                    layers.crop_and_concat([stage[-1], down_path[src - 1][-1]],
                                           axis=3))

                # projection True to make it work with res units.
                for part in (1, 2, 3):
                    stage.append(
                        conv_unit(stage[-1],
                                  'conv_%d_%d' % (step, part),
                                  num_filters=num_channels[src],
                                  training=training,
                                  normalisation=norm,
                                  add_bias=add_bias))

                up_path.append(stage)

        # Three 1x1 recombination convs on the finest decoder output.
        recomb = up_path[-1][-1]
        for idx in range(3):
            recomb = conv_unit(recomb,
                               'recomb_%d' % idx,
                               num_filters=num_channels[0],
                               kernel_size=(1, 1),
                               training=training,
                               normalisation=norm,
                               add_bias=add_bias)

        epsilon = 1e-5

        mean_map = layers.conv2D(recomb,
                                 'mean',
                                 num_filters=n_classes,
                                 kernel_size=(1, 1),
                                 activation=tf.identity)
        log_diag_map = layers.conv2D(recomb,
                                     'diag',
                                     num_filters=n_classes,
                                     kernel_size=(1, 1),
                                     activation=tf.identity)
        factor_map = layers.conv2D(recomb,
                                   'factor',
                                   num_filters=n_classes * rank,
                                   kernel_size=(1, 1),
                                   activation=tf.identity)

        # Flatten spatial/class axes so the Gaussian acts on vectors.
        shape = image_size[:-1] + (n_classes, )
        flat_size = np.prod(shape)
        mean = tf.reshape(mean_map, [-1, flat_size])
        # exp keeps the diagonal positive; epsilon guards against collapse.
        cov_diag = tf.exp(tf.reshape(log_diag_map, [-1, flat_size])) + epsilon
        cov_factor = tf.reshape(factor_map, [-1, flat_size, rank])

        if diagonal:
            dist = DiagonalMultivariateNormal(loc=mean, cov_diag=cov_diag)
        else:
            dist = LowRankMultivariateNormal(loc=mean,
                                             cov_diag=cov_diag,
                                             cov_factor=cov_factor)

        # One reparameterised sample, reshaped back to image layout.
        s = tf.reshape(dist.rsample((1, )), (-1, ) + shape)
        return [[s], dist]
Code example #4
File: likelihoods.py — project: Tt199919/PHiSeg-code
def prob_unet2D(z_list,
                training,
                image_size,
                n_classes,
                scope_reuse=False,
                norm=tfnorm.batch_norm,
                **kwargs):
    """Probabilistic U-Net likelihood network: p(s | x, z).

    Runs a standard U-Net over the image ``x`` (passed via kwargs), tiles
    the single global latent ``z_list[0]`` across the spatial grid of the
    finest decoder output, fuses the two with three 1x1 convolutions and
    projects to one logits map with ``n_classes`` channels.

    Returns a single-element list holding the logits tensor.
    """
    x = kwargs.get('x')

    z = z_list[0]

    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    conv_unit = layers.conv2D
    deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    bs = tf.shape(x)[0]
    zdim = z.get_shape().as_list()[-1]

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm already provides a learned shift, so skip conv biases.
        add_bias = norm != tfnorm.batch_norm

        down = []

        with tf.variable_scope('encoder'):

            for level in range(resolution_levels):

                # Level 0 starts from the raw image; deeper levels from a
                # pooled copy of the previous level's output.
                head = x if level == 0 else layers.averagepool2D(down[level - 1][-1])
                stage = [head]

                for step in (1, 2, 3):
                    stage.append(
                        conv_unit(stage[-1],
                                  'conv_%d_%d' % (level, step),
                                  num_filters=num_channels[level],
                                  training=training,
                                  normalisation=norm,
                                  add_bias=add_bias))

                down.append(stage)

        up = []

        with tf.variable_scope('decoder'):

            for depth in range(resolution_levels - 1):

                src = resolution_levels - depth - 1  # matching encoder level

                prev = down[src][-1] if depth == 0 else up[depth - 1][-1]

                stage = [deconv_unit(prev)]

                # Skip connection from the encoder one level shallower.
                stage.append(
                    layers.crop_and_concat([stage[-1], down[src - 1][-1]],
                                           axis=3))

                # projection True to make it work with res units.
                for step in (1, 2, 3):
                    stage.append(
                        conv_unit(stage[-1],
                                  'conv_%d_%d' % (depth, step),
                                  num_filters=num_channels[src],
                                  training=training,
                                  normalisation=norm,
                                  add_bias=add_bias))

                up.append(stage)

        # Broadcast the flat latent vector over the full spatial grid and
        # attach it to the finest feature map.
        z_t = tf.reshape(z, tf.stack((bs, 1, 1, zdim)))
        broadcast_z = tf.tile(z_t, (1, image_size[0], image_size[1], 1))
        net = tf.concat([up[-1][-1], broadcast_z], axis=-1)

        recomb = net
        for k in range(3):
            recomb = conv_unit(recomb,
                               'recomb_%d' % k,
                               num_filters=num_channels[0],
                               kernel_size=(1, 1),
                               training=training,
                               normalisation=norm,
                               add_bias=add_bias)

        s = [
            layers.conv2D(recomb,
                          'prediction',
                          num_filters=n_classes,
                          kernel_size=(1, 1),
                          activation=tf.identity)
        ]

        return s
Code example #5
def prob_unet2D_arch(
        x,
        training,
        nlabels,
        n0=32,
        resolution_levels=7,
        norm=tfnorm.batch_norm,
        conv_unit=layers.conv2D,
        deconv_unit=lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2),
        scope_reuse=False,
        return_net=False):
    """Deterministic U-Net backbone used by the probabilistic U-Net.

    Encoder/decoder with skip connections and three convolutions per
    resolution level, followed by a three-layer 1x1 'recomb' head and a
    final 1x1 projection to ``nlabels`` logits.  All variables live under
    scope 'likelihood'.
    """

    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Convs keep their bias only when the normalisation is not batch norm.
        add_bias = norm != tfnorm.batch_norm

        def _conv_block(stage, idx, nfilt):
            # Append three successive convs to `stage`, named conv_<idx>_<k>.
            for k in (1, 2, 3):
                stage.append(
                    conv_unit(stage[-1],
                              'conv_%d_%d' % (idx, k),
                              num_filters=nfilt,
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        enc = []

        with tf.variable_scope('encoder'):

            for lvl in range(resolution_levels):

                # Level 0 consumes x; deeper levels a pooled copy of the
                # previous level's output.
                head = x if lvl == 0 else layers.averagepool2D(enc[lvl - 1][-1])
                enc.append([head])
                _conv_block(enc[lvl], lvl, num_channels[lvl])

        dec = []

        with tf.variable_scope('decoder'):

            for step in range(resolution_levels - 1):

                mirror = resolution_levels - step - 1  # encoder level to mirror

                source = enc[mirror][-1] if step == 0 else dec[step - 1][-1]

                dec.append([deconv_unit(source)])

                # Skip connection with the encoder output one level shallower.
                dec[step].append(
                    layers.crop_and_concat([dec[step][-1], enc[mirror - 1][-1]],
                                           axis=3))

                # projection True to make it work with res units.
                _conv_block(dec[step], step, num_channels[mirror])

        recomb = dec[-1][-1]
        for r in range(3):
            recomb = conv_unit(recomb,
                               'recomb_%d' % r,
                               num_filters=num_channels[0],
                               kernel_size=(1, 1),
                               training=training,
                               normalisation=norm,
                               add_bias=add_bias)

        s = layers.conv2D(recomb,
                          'prediction',
                          num_filters=nlabels,
                          kernel_size=(1, 1),
                          activation=tf.identity)

        return s
Code example #6
def phiseg(x,
           s_oh,
           zdim_0,
           training,
           scope_reuse=False,
           norm=tfnorm.batch_norm,
           **kwargs):
    """PHiSeg posterior network q(z | x, s) with diagonal Gaussians.

    Builds a convolutional pyramid over the image concatenated with the
    zero-centred one-hot segmentation, then samples ``latent_levels`` latent
    maps from coarse to fine.  The coarsest level is conditioned on pyramid
    features only; every finer level additionally sees the upsampled sample
    from the level below.  All levels use element-wise diagonal Gaussians.

    Args:
        x: Input image, shape (batch, H, W, C).
        s_oh: One-hot segmentation, same spatial size as x.
        zdim_0: Number of channels of each latent map.
        training: Flag forwarded to the normalisation layers.
        scope_reuse: Reuse variables of an earlier call under scope
            'posterior'.
        norm: Normalisation layer constructor.
        **kwargs: 'n0' (base channel count, default 32), 'latent_levels'
            (default 5), 'resolution_levels' (default 7).

    Returns:
        (z, mu, sigma): lists of length latent_levels with the sampled
        latent maps and their Gaussian parameters, finest level first.
    """

    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('posterior') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Cleaned up: the unused 'full_cov_list' kwarg, a duplicate n0
        # lookup, and dead spatial-size bookkeeping (leftovers from the
        # full-covariance variant of this network) were removed; this
        # variant is diagonal-covariance only.
        latent_levels = kwargs.get('latent_levels', 5)
        resolution_levels = kwargs.get('resolution_levels', 7)

        pre_z = [None] * resolution_levels

        mu = [None] * latent_levels
        sigma = [None] * latent_levels
        z = [None] * latent_levels

        z_ups_mat = []
        for i in range(latent_levels):
            z_ups_mat.append(
                [None] *
                latent_levels)  # encoding [original resolution][upsampled to]

        # Generate pre_z's: plain conv pyramid over (x, s_oh - 0.5).
        for i in range(resolution_levels):

            if i == 0:
                # Centre the one-hot labels around zero before concatenating.
                net = tf.concat([x, s_oh - 0.5], axis=-1)
            else:
                net = layers.averagepool2D(pre_z[i - 1])

            net = layers.conv2D(net,
                                'z%d_pre_1' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_2' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_3' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)

            pre_z[i] = net

        # Generate z's, from the coarsest latent level downwards.
        for i in reversed(range(latent_levels)):

            if i == latent_levels - 1:

                # Coarsest level: conditioned on pyramid features only.
                # NOTE(review): mu uses the default kernel size while sigma
                # uses (1, 1) — possibly unintended asymmetry; confirm.
                mu[i] = layers.conv2D(pre_z[i + resolution_levels -
                                            latent_levels],
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)

                # Softplus keeps the per-pixel stddev strictly positive.
                sigma[i] = layers.conv2D(pre_z[i + resolution_levels -
                                               latent_levels],
                                         'z%d_sigma' % i,
                                         num_filters=zdim_0,
                                         activation=tf.nn.softplus,
                                         kernel_size=(1, 1))
                # Reparameterisation trick: z = mu + sigma * eps.
                z[i] = mu[i] + sigma[i] * tf.random_normal(
                    tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            else:

                # Upsample the sample from below to this level's resolution
                # (and keep intermediate resolutions for finer levels).
                for j in reversed(range(0, i + 1)):

                    z_below_ups = layers.bilinear_upsample2D(
                        z_ups_mat[j + 1][i + 1], factor=2, name='ups')
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_1' %
                                                ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_2' %
                                                ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)

                    z_ups_mat[j][i + 1] = z_below_ups

                # Condition on the pyramid features plus the sample from the
                # level immediately below.
                z_input = tf.concat([
                    pre_z[i + resolution_levels - latent_levels],
                    z_ups_mat[i][i + 1]
                ],
                                    axis=3,
                                    name='concat_%d' % i)

                z_input = layers.conv2D(z_input,
                                        'z%d_input_1' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm,
                                        training=training)
                z_input = layers.conv2D(z_input,
                                        'z%d_input_2' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm,
                                        training=training)

                mu[i] = layers.conv2D(z_input,
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity,
                                      kernel_size=(1, 1))

                sigma[i] = layers.conv2D(z_input,
                                         'z%d_sigma' % i,
                                         num_filters=zdim_0,
                                         activation=tf.nn.softplus,
                                         kernel_size=(1, 1))
                z[i] = mu[i] + sigma[i] * tf.random_normal(
                    tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            # Seed of the upsampling matrix: level i's own-resolution sample.
            z_ups_mat[i][i] = z[i]

    return z, mu, sigma