예제 #1
0
    def increase_resolution(x, times, num_filters, name):
        """Upsample `x` by a factor of 2, `times` times in a row.

        Each bilinear upsampling step is followed by a convolution with
        `num_filters` output channels. All ops live under variable scope
        `name`. Uses `norm` and `training` from the enclosing scope.
        """
        net_out = x
        with tf.variable_scope(name):
            for step in range(times):
                upsampled = layers.bilinear_upsample2D(net_out, 'ups_%d' % step, 2)
                net_out = layers.conv2D(upsampled,
                                        'z%d_post' % step,
                                        num_filters=num_filters,
                                        normalisation=norm,
                                        training=training)
        return net_out
예제 #2
0
def resize_features(features, size, name):
    """Resize every feature map in `features` and concatenate the results.

    Each feature map is passed through `layers.bilinear_upsample2D` and the
    resized maps are concatenated along the channel axis.

    Args:
        features: Sequence of 4D feature tensors (NHWC assumed).
        size: Target size forwarded to `layers.bilinear_upsample2D`.
            NOTE(review): other call sites in this file pass a *name* in this
            positional slot and a factor third — confirm the intended
            signature of `bilinear_upsample2D`.
        name: Base op name; the feature index is appended per map.

    Returns:
        A single tensor with all resized maps concatenated on axis -1
        (the lone resized map if only one feature is given).
    """
    resized = []
    for f, this_feature in enumerate(features):
        resized.append(layers.bilinear_upsample2D(this_feature,
                                                  size,
                                                  name + str(f)))

    # One concat over the whole list instead of the original chain of
    # pairwise concats (same values, fewer graph ops). The original's
    # `if f is 0:` identity check on an int is also gone — `is` on small
    # ints only works by CPython accident.
    if len(resized) == 1:
        return resized[0]
    return tf.concat(resized, axis=-1)
예제 #3
0
    def increase_resolution(x, times, name):
        """Double the spatial resolution of `x`, `times` times.

        After each bilinear upsampling step a convolution doubles the
        channel count, capped at `max_channels` from the enclosing scope.
        Also relies on `norm` and `training` from the enclosing scope.
        """
        net = x
        with tf.variable_scope(name):
            for step in range(times):
                net = layers.bilinear_upsample2D(net, 'ups_%d' % step, 2)
                # Current channel count — NHWC layout assumed (axis 3).
                channels = net.get_shape().as_list()[3]
                net = layers.conv2D(net,
                                    'z%d_post' % step,
                                    num_filters=min(channels * 2, max_channels),
                                    normalisation=norm,
                                    training=training)
        return net
예제 #4
0
def hybrid(x,
           s_oh,
           zdim_0,
           training,
           scope_reuse=False,
           norm=tfnorm.batch_norm,
           **kwargs):
    """Posterior network producing a hierarchy of Gaussian latents.

    Encodes the image `x` together with the one-hot segmentation `s_oh`
    through a downsampling conv stack, then samples latent variables level
    by level from coarse to fine. Depending on `full_cov_list`, a level's
    Gaussian is either diagonal (per-pixel stddev) or has a full spatial
    covariance parameterised by a Cholesky factor.

    Args:
        x: Input image batch, NHWC.
        s_oh: One-hot segmentation, same spatial size as `x`.
        zdim_0: Number of latent channels per level.
        training: Training-mode flag forwarded to the normalisation layers.
        scope_reuse: If True, reuse variables inside the 'posterior' scope.
        norm: Normalisation function used in the conv blocks.
        **kwargs: Optional settings: 'n0' (base channels, default 32),
            'latent_levels' (default 5), 'resolution_levels' (default 7),
            'full_cov_list' (per-level full-covariance flags),
            'full_latent_dependencies' (default False — condition each level
            on *all* coarser levels instead of only the next one).

    Returns:
        Tuple (z, mu, sigma) of lists indexed by latent level (0 = finest).
        For full-covariance levels `sigma[i]` is a Cholesky factor of shape
        [batch, zdim_0, HW, HW]; otherwise it is a per-pixel stddev map.
    """

    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('posterior') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # NOTE(review): defaults to None but is indexed as full_cov_list[i]
        # below — callers must always provide it; confirm that is intended.
        full_cov_list = kwargs.get('full_cov_list', None)

        n0 = kwargs.get('n0', 32)
        latent_levels = kwargs.get('latent_levels', 5)
        resolution_levels = kwargs.get('resolution_levels', 7)

        # Spatial (H, W) of the input; static shape must be known.
        spatial_xdim = x.get_shape().as_list()[1:3]

        full_latent_dependencies = kwargs.get('full_latent_dependencies',
                                              False)

        pre_z = [None] * resolution_levels

        mu = [None] * latent_levels
        sigma = [None] * latent_levels
        z = [None] * latent_levels

        z_ups_mat = []
        for i in range(latent_levels):
            z_ups_mat.append(
                [None] *
                latent_levels)  # encoding [original resolution][upsampled to]

        # Generate pre_z's
        for i in range(resolution_levels):

            if i == 0:
                # Centre the one-hot labels around zero before concatenating
                # them with the image.
                net = tf.concat([x, s_oh - 0.5], axis=-1)
            else:
                net = layers.averagepool2D(pre_z[i - 1])

            net = layers.conv2D(net,
                                'z%d_pre_1' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_2' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_3' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)

            pre_z[i] = net

        # Generate z's — iterate coarse-to-fine so each level can condition
        # on the samples of the coarser levels.
        for i in reversed(range(latent_levels)):

            # Spatial size of this latent level (halved per extra
            # resolution level above the latent hierarchy).
            spatial_zdim = [
                d // 2**(i + resolution_levels - latent_levels)
                for d in spatial_xdim
            ]
            spatial_cov_dim = spatial_zdim[0] * spatial_zdim[1]

            if i == latent_levels - 1:

                # Coarsest level: parameterised directly from the encoder.
                mu[i] = layers.conv2D(pre_z[i + resolution_levels -
                                            latent_levels],
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)

                if full_cov_list[i] == True:

                    # Predict the lower-triangular entries of a per-channel
                    # spatial covariance Cholesky factor.
                    l = layers.dense_layer(
                        pre_z[i + resolution_levels - latent_levels],
                        'z%d_sigma' % i,
                        hidden_units=zdim_0 * spatial_cov_dim *
                        (spatial_cov_dim + 1) // 2,
                        activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim *
                        (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp, tf.nn.softplus(tf.linalg.diag_part(Lp))
                    )  # Cholesky factors must have positive diagonal

                    sigma[i] = L

                    # Reparameterised sample: z = mu + L @ eps, with eps
                    # flattened over the spatial dimensions.
                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))

                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp,
                        [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])

                    z[i] = mu[i] + eps_tmp

                else:

                    # Diagonal Gaussian: per-pixel stddev via softplus.
                    sigma[i] = layers.conv2D(pre_z[i + resolution_levels -
                                                   latent_levels],
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus,
                                             kernel_size=(1, 1))
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            else:

                # Upsample the sample(s) from the level below to every finer
                # resolution that may need them.
                for j in reversed(range(0, i + 1)):

                    z_below_ups = layers.bilinear_upsample2D(
                        z_ups_mat[j + 1][i + 1], factor=2, name='ups')
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_1' %
                                                ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_2' %
                                                ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)

                    z_ups_mat[j][i + 1] = z_below_ups

                # Condition this level on either all coarser samples or
                # only the immediately coarser one.
                if full_latent_dependencies:
                    z_input = tf.concat(
                        [pre_z[i + resolution_levels - latent_levels]] +
                        z_ups_mat[i][(i + 1):latent_levels],
                        axis=3,
                        name='concat_%d' % i)
                else:
                    z_input = tf.concat([
                        pre_z[i + resolution_levels - latent_levels],
                        z_ups_mat[i][i + 1]
                    ],
                                        axis=3,
                                        name='concat_%d' % i)

                z_input = layers.conv2D(z_input,
                                        'z%d_input_1' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm,
                                        training=training)
                z_input = layers.conv2D(z_input,
                                        'z%d_input_2' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm,
                                        training=training)

                mu[i] = layers.conv2D(z_input,
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity,
                                      kernel_size=(1, 1))

                if full_cov_list[i] == True:

                    # Full-covariance branch, analogous to the coarsest
                    # level above but conditioned on z_input.
                    l = layers.dense_layer(z_input,
                                           'z%d_sigma' % i,
                                           hidden_units=zdim_0 *
                                           spatial_cov_dim *
                                           (spatial_cov_dim + 1) // 2,
                                           activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim *
                        (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp, tf.nn.softplus(tf.linalg.diag_part(Lp)))

                    sigma[i] = L

                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))

                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp,
                        [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])

                    z[i] = mu[i] + eps_tmp

                else:

                    sigma[i] = layers.conv2D(z_input,
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus,
                                             kernel_size=(1, 1))
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            # Store this level's own sample on the matrix diagonal so the
            # next (finer) iteration can upsample it.
            z_ups_mat[i][i] = z[i]

    return z, mu, sigma
예제 #5
0
def proposed(z_list,
             training,
             image_size,
             n_classes,
             scope_reuse=False,
             norm=tfnorm.batch_norm,
             rank=10,
             diagonal=False,
             **kwargs):
    """Likelihood network: U-Net backbone with a (low-rank) Gaussian head.

    Runs a standard encoder/decoder over `x` (taken from kwargs), then
    predicts the parameters of a multivariate normal over the flattened
    logit map: a mean, a positive diagonal, and — unless `diagonal` is
    True — a rank-`rank` covariance factor.

    Args:
        z_list: Unused in the visible body (kept for interface parity with
            the sibling likelihood networks).
        training: Training-mode flag for the normalisation layers.
        image_size: Spatial input size tuple; `image_size[:-1]` gives (H, W).
        n_classes: Number of output classes / logit channels.
        scope_reuse: If True, reuse variables in the 'likelihood' scope.
        norm: Normalisation function for the conv blocks.
        rank: Rank of the low-rank covariance factor.
        diagonal: If True, use a diagonal-covariance Gaussian instead.
        **kwargs: Must contain 'x' (input image, NHWC); optional
            'resolution_levels' (default 7) and 'n0' (default 32).

    Returns:
        [[s], dist] where `s` is one sample reshaped to image layout and
        `dist` is the (diagonal or low-rank) multivariate normal.
    """
    x = kwargs.get('x')

    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    conv_unit = layers.conv2D
    deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm has its own shift term, so skip the conv bias then.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []

        with tf.variable_scope('encoder'):

            for ii in range(resolution_levels):

                enc.append([])

                # In first layer set input to x rather than max pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))

                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_1' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_2' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_3' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        dec = []

        with tf.variable_scope('decoder'):

            for jj in range(resolution_levels - 1):

                ii = resolution_levels - jj - 1  # used to index the encoder again

                dec.append([])

                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]

                dec[jj].append(deconv_unit(next_inp))

                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]],
                                           axis=3))

                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_1' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_2' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_3' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        net = dec[-1][-1]

        # 1x1 recombination convs before the distribution heads.
        recomb = conv_unit(net,
                           'recomb_0',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_1',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_2',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)

        # Floor for the covariance diagonal, for numerical stability.
        epsilon = 1e-5

        mean = layers.conv2D(recomb,
                             'mean',
                             num_filters=n_classes,
                             kernel_size=(1, 1),
                             activation=tf.identity)
        log_cov_diag = layers.conv2D(recomb,
                                     'diag',
                                     num_filters=n_classes,
                                     kernel_size=(1, 1),
                                     activation=tf.identity)
        cov_factor = layers.conv2D(recomb,
                                   'factor',
                                   num_filters=n_classes * rank,
                                   kernel_size=(1, 1),
                                   activation=tf.identity)

        # Flatten spatial + class dims into one event dimension.
        shape = image_size[:-1] + (n_classes, )
        flat_size = np.prod(shape)
        mean = tf.reshape(mean, [-1, flat_size])
        cov_diag = tf.exp(tf.reshape(log_cov_diag, [-1, flat_size])) + epsilon
        cov_factor = tf.reshape(cov_factor, [-1, flat_size, rank])
        if diagonal:
            dist = DiagonalMultivariateNormal(loc=mean, cov_diag=cov_diag)
        else:
            dist = LowRankMultivariateNormal(loc=mean,
                                             cov_diag=cov_diag,
                                             cov_factor=cov_factor)

        # Draw one reparameterised sample and restore image layout.
        s = dist.rsample((1, ))
        s = tf.reshape(s, (-1, ) + shape)
        return [[s], dist]
예제 #6
0
def phiseg(z_list,
           training,
           image_size,
           n_classes,
           scope_reuse=False,
           norm=tfnorm.batch_norm,
           **kwargs):
    """
    This is a U-NET like arch with skips before and after latent space and a rather simple decoder

    Decodes the per-level latents in `z_list` into per-level segmentation
    logit maps: each latent is processed, upsampled to the finest latent
    resolution, combined top-down with the level below, and projected to
    `n_classes` channels at full image size.

    Args:
        z_list: Latent tensors, one per latent level (index 0 = finest).
        training: Training-mode flag for the normalisation layers.
        image_size: Target output size; `image_size[0:2]` gives (H, W).
        n_classes: Number of output classes.
        scope_reuse: If True, reuse variables in the 'likelihood' scope.
        norm: Normalisation function for the conv blocks.
        **kwargs: Optional 'n0' (default 32), 'resolution_levels'
            (default 7), 'latent_levels' (default 5).

    Returns:
        List `s` of per-level logit tensors, each resized (nearest
        neighbour) to `image_size[0:2]`.
    """

    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    def increase_resolution(x, times, num_filters, name):
        # Upsample `x` by 2x `times` times, with a conv after each step.

        with tf.variable_scope(name):
            nett = x

            for i in range(times):
                nett = layers.bilinear_upsample2D(nett, 'ups_%d' % i, 2)
                nett = layers.conv2D(nett,
                                     'z%d_post' % i,
                                     num_filters=num_filters,
                                     normalisation=norm,
                                     training=training)

        return nett

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        resolution_levels = kwargs.get('resolution_levels', 7)
        latent_levels = kwargs.get('latent_levels', 5)
        # Offset between resolution and latent hierarchies.
        lvl_diff = resolution_levels - latent_levels

        post_z = [None] * latent_levels
        post_c = [None] * latent_levels

        s = [None] * latent_levels

        # Generate post_z
        for i in range(latent_levels):
            net = layers.conv2D(z_list[i],
                                'z%d_post_1' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_post_2' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            # Bring every level up to the finest latent resolution.
            net = increase_resolution(net,
                                      resolution_levels - latent_levels,
                                      num_filters=num_channels[i],
                                      name='preups_%d' % i)

            post_z[i] = net

        # Upstream path
        post_c[latent_levels - 1] = post_z[latent_levels - 1]

        # Combine each level with the upsampled level below it (coarse to
        # fine).
        for i in reversed(range(latent_levels - 1)):
            ups_below = layers.bilinear_upsample2D(post_c[i + 1],
                                                   name='post_z%d_ups' %
                                                   (i + 1),
                                                   factor=2)
            ups_below = layers.conv2D(ups_below,
                                      'post_z%d_ups_c' % (i + 1),
                                      num_filters=num_channels[i],
                                      normalisation=norm,
                                      training=training)

            concat = tf.concat([post_z[i], ups_below],
                               axis=3,
                               name='concat_%d' % i)

            net = layers.conv2D(concat,
                                'post_c_%d_1' % i,
                                num_filters=num_channels[i + lvl_diff],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'post_c_%d_2' % i,
                                num_filters=num_channels[i + lvl_diff],
                                normalisation=norm,
                                training=training)

            post_c[i] = net

        # Outputs
        for i in range(latent_levels):
            s_in = layers.conv2D(post_c[i],
                                 'y_lvl%d' % i,
                                 num_filters=n_classes,
                                 kernel_size=(1, 1),
                                 activation=tf.identity)
            # Nearest-neighbour resize so each level's logits match the
            # full image resolution.
            s[i] = tf.image.resize_images(
                s_in,
                image_size[0:2],
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        return s
예제 #7
0
def prob_unet2D(z_list,
                training,
                image_size,
                n_classes,
                scope_reuse=False,
                norm=tfnorm.batch_norm,
                **kwargs):
    """Probabilistic U-Net likelihood: U-Net plus a broadcast latent.

    Runs a standard encoder/decoder over `x` (from kwargs), tiles the
    single latent vector `z_list[0]` across the spatial dimensions,
    concatenates it with the decoder output, and predicts class logits.

    Args:
        z_list: List whose first element is the latent vector (last axis
            is the latent dimension; spatial dims are broadcast away).
        training: Training-mode flag for the normalisation layers.
        image_size: (H, W, ...) used to tile the latent spatially.
        n_classes: Number of output classes.
        scope_reuse: If True, reuse variables in the 'likelihood' scope.
        norm: Normalisation function for the conv blocks.
        **kwargs: Must contain 'x' (input image, NHWC); optional
            'resolution_levels' (default 7) and 'n0' (default 32).

    Returns:
        Single-element list containing the logit tensor.
    """
    x = kwargs.get('x')

    z = z_list[0]

    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    conv_unit = layers.conv2D
    deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    bs = tf.shape(x)[0]
    zdim = z.get_shape().as_list()[-1]

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm has its own shift term, so skip the conv bias then.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []

        with tf.variable_scope('encoder'):

            for ii in range(resolution_levels):

                enc.append([])

                # In first layer set input to x rather than max pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))

                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_1' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_2' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_3' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        dec = []

        with tf.variable_scope('decoder'):

            for jj in range(resolution_levels - 1):

                ii = resolution_levels - jj - 1  # used to index the encoder again

                dec.append([])

                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]

                dec[jj].append(deconv_unit(next_inp))

                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]],
                                           axis=3))

                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_1' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_2' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_3' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        # Broadcast the latent vector over the full spatial extent and
        # attach it to the decoder features.
        z_t = tf.reshape(z, tf.stack((bs, 1, 1, zdim)))

        broadcast_z = tf.tile(z_t, (1, image_size[0], image_size[1], 1))

        net = tf.concat([dec[-1][-1], broadcast_z], axis=-1)

        # 1x1 recombination convs mixing decoder features with the latent.
        recomb = conv_unit(net,
                           'recomb_0',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_1',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_2',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)

        s = [
            layers.conv2D(recomb,
                          'prediction',
                          num_filters=n_classes,
                          kernel_size=(1, 1),
                          activation=tf.identity)
        ]

        return s
예제 #8
0
def prob_unet2D_arch(
        x,
        training,
        nlabels,
        n0=32,
        resolution_levels=7,
        norm=tfnorm.batch_norm,
        conv_unit=layers.conv2D,
        deconv_unit=lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2),
        scope_reuse=False,
        return_net=False):
    """Plain (deterministic) U-Net architecture used as a backbone.

    Standard encoder/decoder with skip connections, three 1x1
    recombination convs, and a final 1x1 projection to `nlabels` logits.

    Args:
        x: Input image batch, NHWC.
        training: Training-mode flag for the normalisation layers.
        nlabels: Number of output classes.
        n0: Base channel count; deeper levels use multiples of it.
        resolution_levels: Number of encoder resolution levels.
        norm: Normalisation function for the conv blocks.
        conv_unit: Convolution layer constructor.
        deconv_unit: Upsampling layer constructor.
        scope_reuse: If True, reuse variables in the 'likelihood' scope.
        return_net: NOTE(review): accepted but never used in the visible
            body — confirm whether an intermediate-features return path
            was intended.

    Returns:
        Logit tensor of shape [batch, H, W, nlabels].
    """

    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm has its own shift term, so skip the conv bias then.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []

        with tf.variable_scope('encoder'):

            for ii in range(resolution_levels):

                enc.append([])

                # In first layer set input to x rather than max pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))

                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_1' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_2' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_3' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        dec = []

        with tf.variable_scope('decoder'):

            for jj in range(resolution_levels - 1):

                ii = resolution_levels - jj - 1  # used to index the encoder again

                dec.append([])

                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]

                dec[jj].append(deconv_unit(next_inp))

                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]],
                                           axis=3))

                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_1' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_2' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_3' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        # 1x1 recombination convs before the final projection.
        recomb = conv_unit(dec[-1][-1],
                           'recomb_0',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_1',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_2',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)

        s = layers.conv2D(recomb,
                          'prediction',
                          num_filters=nlabels,
                          kernel_size=(1, 1),
                          activation=tf.identity)

        return s
예제 #9
0
def phiseg(x,
           s_oh,
           zdim_0,
           training,
           scope_reuse=False,
           norm=tfnorm.batch_norm,
           **kwargs):
    """Posterior network of a hierarchical latent segmentation model.

    Builds `latent_levels` levels of diagonal-Gaussian latents conditioned
    on the image `x` and the one-hot segmentation `s_oh`, sampling each
    level with the reparameterisation trick (mu + sigma * eps). Coarser
    latents are upsampled and fed into the finer levels.

    Args:
        x: Input image tensor, NHWC. NOTE(review): assumes spatial dims
            are divisible by 2**(resolution_levels - 1) — confirm with caller.
        s_oh: One-hot segmentation tensor with the same spatial size as x
            (concatenated to x as s_oh - 0.5 at the first level).
        zdim_0: Number of latent channels per level.
        training: Boolean tensor/placeholder passed to the normalisation layers.
        scope_reuse: If True, reuse variables inside the 'posterior' scope.
        norm: Normalisation function used inside the conv blocks.
        **kwargs: 'n0' (base filter count, default 32), 'latent_levels'
            (default 5), 'resolution_levels' (default 7), 'full_cov_list'
            (read below but never used in this function).

    Returns:
        Tuple (z, mu, sigma): lists of length `latent_levels`, index 0 being
        the highest-resolution level; z[i] = mu[i] + sigma[i] * eps.
    """

    n0 = kwargs.get('n0', 32)
    # Channel schedule per resolution level (capped at 6*n0 for deep levels).
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('posterior') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # NOTE(review): read but never used anywhere in this function.
        full_cov_list = kwargs.get('full_cov_list', None)

        n0 = kwargs.get('n0', 32)  # re-read; identical to the value above
        latent_levels = kwargs.get('latent_levels', 5)
        resolution_levels = kwargs.get('resolution_levels', 7)

        # Static spatial (H, W) of the input; requires a statically known shape.
        spatial_xdim = x.get_shape().as_list()[1:3]

        pre_z = [None] * resolution_levels

        mu = [None] * latent_levels
        sigma = [None] * latent_levels
        z = [None] * latent_levels

        z_ups_mat = []
        for i in range(latent_levels):
            z_ups_mat.append(
                [None] *
                latent_levels)  # encoding [original resolution][upsampled to]

        # Generate pre_z's
        # Encoder: each level halves the resolution (average pooling) and
        # applies three conv blocks at that level's channel count.
        for i in range(resolution_levels):

            if i == 0:
                # Center the one-hot labels around 0 before concatenation.
                net = tf.concat([x, s_oh - 0.5], axis=-1)
            else:
                net = layers.averagepool2D(pre_z[i - 1])

            net = layers.conv2D(net,
                                'z%d_pre_1' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_2' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_3' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)

            pre_z[i] = net

        # Generate z's
        # Decoder over latents: coarsest level first, finer levels condition
        # on the upsampled sample from the level below.
        for i in reversed(range(latent_levels)):

            # Spatial size of this latent level (derived, not used directly
            # except via spatial_cov_dim below).
            spatial_zdim = [
                d // 2**(i + resolution_levels - latent_levels)
                for d in spatial_xdim
            ]
            # NOTE(review): unused in the remainder of this function.
            spatial_cov_dim = spatial_zdim[0] * spatial_zdim[1]

            if i == latent_levels - 1:

                # Coarsest level: mu/sigma come straight from the encoder feature.
                # NOTE(review): this mu conv uses the default kernel size,
                # unlike the matching sigma conv (1x1) — verify intentional.
                mu[i] = layers.conv2D(pre_z[i + resolution_levels -
                                            latent_levels],
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)

                sigma[i] = layers.conv2D(pre_z[i + resolution_levels -
                                               latent_levels],
                                         'z%d_sigma' % i,
                                         num_filters=zdim_0,
                                         activation=tf.nn.softplus,
                                         kernel_size=(1, 1))
                # Reparameterisation trick: z = mu + sigma * eps, eps ~ N(0, 1).
                z[i] = mu[i] + sigma[i] * tf.random_normal(
                    tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            else:

                # Upsample the sample from level i+1 step by step towards
                # every finer resolution j <= i, caching each stage in
                # z_ups_mat[j][i + 1].
                for j in reversed(range(0, i + 1)):

                    z_below_ups = layers.bilinear_upsample2D(
                        z_ups_mat[j + 1][i + 1], factor=2, name='ups')
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_1' %
                                                ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_2' %
                                                ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)

                    z_ups_mat[j][i + 1] = z_below_ups

                # Condition this level on the encoder feature at the matching
                # resolution plus the upsampled sample from the level below.
                z_input = tf.concat([
                    pre_z[i + resolution_levels - latent_levels],
                    z_ups_mat[i][i + 1]
                ],
                                    axis=3,
                                    name='concat_%d' % i)

                z_input = layers.conv2D(z_input,
                                        'z%d_input_1' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm,
                                        training=training)
                z_input = layers.conv2D(z_input,
                                        'z%d_input_2' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm,
                                        training=training)

                mu[i] = layers.conv2D(z_input,
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity,
                                      kernel_size=(1, 1))

                sigma[i] = layers.conv2D(z_input,
                                         'z%d_sigma' % i,
                                         num_filters=zdim_0,
                                         activation=tf.nn.softplus,
                                         kernel_size=(1, 1))
                # Reparameterisation trick, as at the coarsest level.
                z[i] = mu[i] + sigma[i] * tf.random_normal(
                    tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            # Seed the upsampling cache: a level's own sample at its own
            # resolution.
            z_ups_mat[i][i] = z[i]

    return z, mu, sigma
예제 #10
0
def unet2D_i2l(images,
               nlabels,
               training_pl,
               scope_reuse=False):
    """2D U-Net mapping input images to per-pixel label logits.

    Encoder: three conv blocks (two batch-normed conv layers each) followed
    by max-pooling, plus a bottleneck block. Decoder: three stages of
    bilinear upsampling, skip-connection concatenation, and two conv layers.
    The final 1x1 conv produces raw logits (no normalisation/activation).

    Args:
        images: Input image batch, NHWC.
        nlabels: Number of output classes (channels of `pred`).
        training_pl: Boolean placeholder controlling batch-norm mode.
        scope_reuse: If True, reuse variables in the 'i2l_mapper' scope.

    Returns:
        Tuple (pool1, pool2, pool3, conv4_2, conv5_2, conv6_2, conv7_2, pred)
        of intermediate encoder/decoder features and the final logits.
    """

    n0 = 16
    n1, n2, n3, n4 = 1 * n0, 2 * n0, 4 * n0, 8 * n0

    with tf.variable_scope('i2l_mapper') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # ====================================
        # 1st Conv block - two conv layers, followed by max-pooling
        # ====================================
        conv1_1 = layers.conv2D_layer_bn(x=images, name='conv1_1', num_filters=n1, training=training_pl)
        conv1_2 = layers.conv2D_layer_bn(x=conv1_1, name='conv1_2', num_filters=n1, training=training_pl)
        pool1 = layers.max_pool_layer2d(conv1_2)

        # ====================================
        # 2nd Conv block
        # ====================================
        conv2_1 = layers.conv2D_layer_bn(x=pool1, name='conv2_1', num_filters=n2, training=training_pl)
        conv2_2 = layers.conv2D_layer_bn(x=conv2_1, name='conv2_2', num_filters=n2, training=training_pl)
        pool2 = layers.max_pool_layer2d(conv2_2)

        # ====================================
        # 3rd Conv block
        # ====================================
        conv3_1 = layers.conv2D_layer_bn(x=pool2, name='conv3_1', num_filters=n3, training=training_pl)
        conv3_2 = layers.conv2D_layer_bn(x=conv3_1, name='conv3_2', num_filters=n3, training=training_pl)
        # BUG FIX: previously pooled conv3_1, so conv3_2 was dropped from the
        # downsampling path and only reached the output via the skip
        # connection. Pool conv3_2, consistent with blocks 1 and 2.
        pool3 = layers.max_pool_layer2d(conv3_2)

        # ====================================
        # 4th Conv block (bottleneck, no pooling)
        # ====================================
        conv4_1 = layers.conv2D_layer_bn(x=pool3, name='conv4_1', num_filters=n4, training=training_pl)
        conv4_2 = layers.conv2D_layer_bn(x=conv4_1, name='conv4_2', num_filters=n4, training=training_pl)

        # ====================================
        # Upsampling via bilinear upsampling, concatenation (skip connection), followed by 2 conv layers
        # ====================================
        deconv3 = layers.bilinear_upsample2D(conv4_2, size=(tf.shape(conv3_2)[1], tf.shape(conv3_2)[2]), name='upconv3')
        concat3 = tf.concat([deconv3, conv3_2], axis=-1)
        conv5_1 = layers.conv2D_layer_bn(x=concat3, name='conv5_1', num_filters=n3, training=training_pl)
        conv5_2 = layers.conv2D_layer_bn(x=conv5_1, name='conv5_2', num_filters=n3, training=training_pl)

        # ====================================
        # Upsampling via bilinear upsampling, concatenation (skip connection), followed by 2 conv layers
        # ====================================
        deconv2 = layers.bilinear_upsample2D(conv5_2, size=(tf.shape(conv2_2)[1], tf.shape(conv2_2)[2]), name='upconv2')
        concat2 = tf.concat([deconv2, conv2_2], axis=-1)
        conv6_1 = layers.conv2D_layer_bn(x=concat2, name='conv6_1', num_filters=n2, training=training_pl)
        conv6_2 = layers.conv2D_layer_bn(x=conv6_1, name='conv6_2', num_filters=n2, training=training_pl)

        # ====================================
        # Upsampling via bilinear upsampling, concatenation (skip connection), followed by 2 conv layers
        # ====================================
        deconv1 = layers.bilinear_upsample2D(conv6_2, size=(tf.shape(conv1_2)[1], tf.shape(conv1_2)[2]), name='upconv1')
        concat1 = tf.concat([deconv1, conv1_2], axis=-1)
        conv7_1 = layers.conv2D_layer_bn(x=concat1, name='conv7_1', num_filters=n1, training=training_pl)
        conv7_2 = layers.conv2D_layer_bn(x=conv7_1, name='conv7_2', num_filters=n1, training=training_pl)

        # ====================================
        # Final conv layer - without batch normalization or activation
        # ====================================
        pred = layers.conv2D_layer(x=conv7_2, name='pred', num_filters=nlabels, kernel_size=1)

    return pool1, pool2, pool3, conv4_2, conv5_2, conv6_2, conv7_2, pred