Exemplo n.º 1
0
def betaVAE_bn(x,
               s_oh,
               zdim_0,
               training,
               scope_reuse=False,
               norm=tfnorm.batch_norm,
               **kwargs):
    """Posterior network of a beta-VAE with batch-normalised convolutions.

    Encodes the image `x` together with the zero-centred one-hot
    segmentation `s_oh` into a single-level Gaussian latent code of
    dimension `zdim_0`.

    Returns:
        (z, mu_z, sigma_z): one-element lists holding the sampled latent,
        its mean, and its softplus standard deviation.
    """
    resolution_levels = kwargs.get('resolution_levels', 5)
    in_hw = x.get_shape().as_list()[1:3]
    # Kernel sized to exactly cover the map left after the strided convs.
    final_kernel_size = [s // (2**(resolution_levels - 1)) for s in in_hw]

    # POSTERIOR ####################
    with tf.variable_scope('posterior') as scope:

        if scope_reuse:
            scope.reuse_variables()

        n0 = kwargs.get('n0', 32)

        # Centre the one-hot labels around zero before concatenating.
        features = tf.concat([x, s_oh - 0.5], axis=-1)

        # Downsampling stack: each stride-2 conv halves the resolution.
        for lvl in range(resolution_levels - 1):
            features = layers.conv2D(features,
                                     'q_z_%d' % lvl,
                                     num_filters=n0 * (lvl // 2 + 1),
                                     kernel_size=(4, 4),
                                     strides=(2, 2),
                                     normalisation=norm,
                                     training=training)

        # VALID conv with a map-sized kernel collapses the spatial dims.
        features = layers.conv2D(features,
                                 'q_z_%d' % resolution_levels,
                                 num_filters=n0 * 8,
                                 kernel_size=final_kernel_size,
                                 strides=(1, 1),
                                 padding='VALID',
                                 normalisation=norm,
                                 training=training)

        mu_0 = layers.dense_layer(features,
                                  'z_mu',
                                  hidden_units=zdim_0,
                                  activation=tf.identity)
        sigma_0 = layers.dense_layer(features,
                                     'z_sigma',
                                     hidden_units=zdim_0,
                                     activation=tf.nn.softplus)

        # Reparameterisation trick: z = mu + sigma * eps, eps ~ N(0, 1).
        noise = tf.random_normal(tf.shape(mu_0), 0, 1, dtype=tf.float32)

        mu_z = [mu_0]
        sigma_z = [sigma_0]
        z = [mu_0 + sigma_0 * noise]

    return z, mu_z, sigma_z
Exemplo n.º 2
0
def id_res_unet2D(x,
                  training,
                  nlabels,
                  n0=64,
                  resolution_levels=5,
                  norm=tfnorm.batch_norm,
                  scope_reuse=False,
                  return_net=False):
    """2D U-Net built from identity residual units, preceded by a stem conv.

    Projects the input to `n0` channels with a single convolution and then
    delegates to `unet2D`, using `layers.identity_residual_unit2D` as the
    convolutional building block.
    """
    # Batch norm has its own learned offset, so a conv bias is redundant.
    use_bias = norm != tfnorm.batch_norm

    stem = layers.conv2D(x,
                         training=training,
                         num_filters=n0,
                         name='input_layer',
                         normalisation=norm,
                         add_bias=use_bias)

    return unet2D(stem,
                  training,
                  nlabels,
                  n0=n0,
                  resolution_levels=resolution_levels,
                  norm=norm,
                  conv_unit=layers.identity_residual_unit2D,
                  scope_reuse=scope_reuse,
                  return_net=return_net)
Exemplo n.º 3
0
    def reduce_resolution(x, times, name):
        """Halve the spatial resolution of `x` `times` times.

        Each step applies `reshape_pool2D_layer` (presumably a
        space-to-depth style pooling — confirm in `layers`) and then a conv
        that reduces channels to a quarter, capped at the closure's
        `max_channels`.
        """
        with tf.variable_scope(name):

            out = x

            for step in range(times):

                out = layers.reshape_pool2D_layer(out)
                channels = out.get_shape().as_list()[3]
                out = layers.conv2D(out,
                                    'down_%d' % step,
                                    num_filters=min(channels // 4, max_channels),
                                    normalisation=norm,
                                    training=training)

        return out
Exemplo n.º 4
0
    def increase_resolution(x, times, num_filters, name):
        """Bilinearly upsample `x` by 2x `times` times.

        Each upsampling step is followed by a conv with `num_filters`
        output channels (norm/training come from the enclosing scope).
        """
        with tf.variable_scope(name):
            out = x

            for step in range(times):
                upsampled = layers.bilinear_upsample2D(out, 'ups_%d' % step, 2)
                out = layers.conv2D(upsampled,
                                    'z%d_post' % step,
                                    num_filters=num_filters,
                                    normalisation=norm,
                                    training=training)

        return out
Exemplo n.º 5
0
    def increase_resolution(x, times, name):
        """Bilinearly upsample `x` by 2x `times` times, widening channels.

        Each step doubles the channel count of the upsampled map, capped at
        the closure's `max_channels`.
        """
        with tf.variable_scope(name):
            out = x

            for step in range(times):
                out = layers.bilinear_upsample2D(out, 'ups_%d' % step, 2)
                in_channels = out.get_shape().as_list()[3]
                out = layers.conv2D(out,
                                    'z%d_post' % step,
                                    num_filters=min(in_channels * 2, max_channels),
                                    normalisation=norm,
                                    training=training)

        return out
Exemplo n.º 6
0
def unet_T_L(x,
             s_oh,
             zdim_0,
             training,
             scope_reuse=False,
             norm=tfnorm.batch_norm,
             **kwargs):
    """Hierarchical posterior network with a ladder of latent levels.

    Builds a U-net-style downsampling encoder over the concatenation of the
    image `x` and the zero-centred one-hot segmentation `s_oh`, then derives
    one latent sample per level, coarsest first; each finer level is
    conditioned on upsampled samples from the level(s) below.

    Args:
        x: Input image tensor; spatial dims are read from its static shape.
        s_oh: One-hot segmentation, concatenated to x after a -0.5 shift.
        zdim_0: Number of latent channels per level.
        training: Passed through to the normalisation layers.
        scope_reuse: If True, reuse variables in the 'posterior' scope.
        norm: Normalisation function applied inside the conv layers.
        **kwargs: full_cov_list, n0, max_channel_power, latent_levels,
            resolution_levels, full_latent_dependencies.

    Returns:
        (z, mu, sigma): lists of length `latent_levels` (index 0 = finest).
        Where full_cov_list[i] is True, sigma[i] is a Cholesky factor of a
        full spatial covariance; otherwise a per-pixel std-dev map.
    """

    # POSTERIOR ####################

    with tf.variable_scope('posterior') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # NOTE(review): defaults to None, but full_cov_list[i] is indexed
        # unconditionally below -- omitting this kwarg raises a TypeError.
        full_cov_list = kwargs.get('full_cov_list', None)

        n0 = kwargs.get('n0', 32)
        max_channel_power = kwargs.get('max_channel_power', 4)
        max_channels = n0 * 2**max_channel_power  # NOTE(review): unused here
        latent_levels = kwargs.get('latent_levels', 4)
        resolution_levels = kwargs.get('resolution_levels', 6)

        spatial_xdim = x.get_shape().as_list()[1:3]

        full_latent_dependencies = kwargs.get('full_latent_dependencies',
                                              False)

        pre_z = [None] * resolution_levels

        mu = [None] * latent_levels
        sigma = [None] * latent_levels
        z = [None] * latent_levels

        z_ups_mat = []
        for i in range(latent_levels):
            z_ups_mat.append(
                [None] *
                latent_levels)  # encoding [original resolution][upsampled to]

        # Generate pre_z's: downsampling encoder path, one feature map per
        # resolution level (level 0 = input resolution).
        for i in range(resolution_levels):

            if i == 0:
                net = tf.concat([x, s_oh - 0.5], axis=-1)
            else:
                net = layers.reshape_pool2D_layer(pre_z[i - 1])

            net = layers.conv2D(net,
                                'z%d_pre_1' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_2' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)

            pre_z[i] = net

        # Generate z's, coarsest latent level first.
        for i in reversed(range(latent_levels)):

            # Spatial size of this latent level's feature map.
            spatial_zdim = [
                d // 2**(i + resolution_levels - latent_levels)
                for d in spatial_xdim
            ]
            spatial_cov_dim = spatial_zdim[0] * spatial_zdim[1]

            if i == latent_levels - 1:
                # Coarsest level: conditioned on the encoder features only.

                mu[i] = layers.conv2D(pre_z[i + resolution_levels -
                                            latent_levels],
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)

                if full_cov_list[i] == True:
                    # Full covariance over spatial positions per latent
                    # channel, parameterised via a Cholesky factor.

                    l = layers.dense_layer(
                        pre_z[i + resolution_levels - latent_levels],
                        'z%d_sigma' % i,
                        hidden_units=zdim_0 * spatial_cov_dim *
                        (spatial_cov_dim + 1) // 2,
                        activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim *
                        (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp, tf.nn.softplus(tf.linalg.diag_part(Lp))
                    )  # Cholesky factors must have positive diagonal

                    sigma[i] = L

                    # Correlated reparameterisation: reshape eps to
                    # (batch, zdim, H*W, 1), multiply by the Cholesky factor,
                    # then restore the NHWC layout.
                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))

                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp,
                        [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])

                    z[i] = mu[i] + eps_tmp

                else:
                    # Diagonal covariance: per-pixel softplus std dev.

                    sigma[i] = layers.conv2D(pre_z[i + resolution_levels -
                                                   latent_levels],
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus,
                                             kernel_size=(1, 1))
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            else:
                # Finer levels: first upsample all coarser samples to this
                # level's resolution (cached in z_ups_mat).

                for j in reversed(range(0, i + 1)):

                    z_below_ups = layers.nearest_neighbour_upsample2D(
                        z_ups_mat[j + 1][i + 1], factor=2)
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_1' %
                                                ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_2' %
                                                ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)

                    z_ups_mat[j][i + 1] = z_below_ups

                # Condition on either all coarser levels or only the one
                # immediately below.
                if full_latent_dependencies:
                    z_input = tf.concat(
                        [pre_z[i + resolution_levels - latent_levels]] +
                        z_ups_mat[i][(i + 1):latent_levels],
                        axis=3,
                        name='concat_%d' % i)
                else:
                    z_input = tf.concat([
                        pre_z[i + resolution_levels - latent_levels],
                        z_ups_mat[i][i + 1]
                    ],
                                        axis=3,
                                        name='concat_%d' % i)

                z_input = layers.conv2D(z_input,
                                        'z%d_input_1' % i,
                                        num_filters=n0 * (i // 2 + 1),
                                        normalisation=norm,
                                        training=training)
                z_input = layers.conv2D(z_input,
                                        'z%d_input_2' % i,
                                        num_filters=n0 * (i // 2 + 1),
                                        normalisation=norm,
                                        training=training)

                mu[i] = layers.conv2D(z_input,
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity,
                                      kernel_size=(1, 1))

                if full_cov_list[i] == True:
                    # Full-covariance branch (same construction as the
                    # coarsest level above).

                    l = layers.dense_layer(z_input,
                                           'z%d_sigma' % i,
                                           hidden_units=zdim_0 *
                                           spatial_cov_dim *
                                           (spatial_cov_dim + 1) // 2,
                                           activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim *
                        (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp, tf.nn.softplus(tf.linalg.diag_part(Lp)))

                    sigma[i] = L

                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))

                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp,
                        [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])

                    z[i] = mu[i] + eps_tmp

                else:

                    sigma[i] = layers.conv2D(z_input,
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus,
                                             kernel_size=(1, 1))
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            # Cache this level's sample at its own resolution.
            z_ups_mat[i][i] = z[i]

    return z, mu, sigma
Exemplo n.º 7
0
def segvae_const_latent(x,
                        s_oh,
                        zdim_0,
                        training,
                        scope_reuse=False,
                        norm=tfnorm.batch_norm,
                        **kwargs):
    """SegVAE posterior in which all latent levels share one (coarsest)
    spatial resolution.

    A downsampling encoder over the concatenation of `x` and the
    zero-centred one-hot segmentation `s_oh` produces one feature map per
    level; each is reduced to the coarsest resolution before the latent
    heads. Levels are sampled coarsest-first, with finer levels
    conditioned on projections of the coarser samples.

    Returns:
        (z, mu, sigma): lists of length `resolution_levels`. Where
        full_cov_list[i] is True, sigma[i] is a Cholesky factor of a full
        spatial covariance; otherwise a per-pixel std-dev map.
    """

    n0 = kwargs.get('n0', 32)
    max_channel_power = kwargs.get('max_channel_power', 4)
    max_channels = n0 * 2**max_channel_power
    # NOTE(review): defaults to None but full_cov_list[i] is indexed below
    # unconditionally -- omitting this kwarg raises a TypeError.
    full_cov_list = kwargs.get('full_cov_list', None)

    resolution_levels = kwargs.get('resolution_levels', 5)

    def reduce_resolution(x, times, name):
        # Halve spatial resolution `times` times; each step pools and then
        # projects channels down to a quarter (capped at max_channels).

        with tf.variable_scope(name):

            nett = x

            for ii in range(times):

                nett = layers.reshape_pool2D_layer(nett)
                nC = nett.get_shape().as_list()[3]
                nett = layers.conv2D(nett,
                                     'down_%d' % ii,
                                     num_filters=min(nC // 4, max_channels),
                                     normalisation=norm,
                                     training=training)

        return nett

    with tf.variable_scope('posterior') as scope:

        spatial_xdim = x.get_shape().as_list()[1:3]
        # All latents share the coarsest level's spatial size.
        spatial_zdim = [d // 2**(resolution_levels - 1) for d in spatial_xdim]
        spatial_cov_dim = spatial_zdim[0] * spatial_zdim[1]

        if scope_reuse:
            scope.reuse_variables()

        n0 = kwargs.get('n0', 32)  # NOTE(review): redundant re-fetch of n0
        levels = resolution_levels

        full_latent_dependencies = kwargs.get('full_latent_dependencies',
                                              False)

        pre_z = [None] * levels
        mu = [None] * levels
        sigma = [None] * levels
        z = [None] * levels

        z_mat = []
        for i in range(levels):
            z_mat.append(
                [None] *
                levels)  # encoding [original resolution][upsampled to]

        # Generate pre_z's: downsampling encoder over image + shifted one-hot.
        for i in range(levels):

            if i == 0:
                net = tf.concat([x, s_oh - 0.5], axis=-1)
            else:
                net = layers.maxpool2D(pre_z[i - 1])

            net = layers.conv2D(net,
                                'z%d_pre_1' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)
            pre_z[i] = net

        # Generate z's, coarsest level first; every level's features are
        # reduced to the shared coarsest latent resolution.
        for i in reversed(range(levels)):

            z_input = reduce_resolution(pre_z[i],
                                        levels - i - 1,
                                        name='reduction_%d' % i)
            logging.info('z_input.shape')
            logging.info(z_input.get_shape().as_list())

            if i == levels - 1:
                # Coarsest level: conditioned on encoder features only.

                mu[i] = layers.conv2D(z_input,
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)

                if full_cov_list[i] == True:
                    # Full spatial covariance per latent channel via a
                    # Cholesky factor.

                    l = layers.dense_layer(z_input,
                                           'z%d_sigma' % i,
                                           hidden_units=zdim_0 *
                                           spatial_cov_dim *
                                           (spatial_cov_dim + 1) // 2,
                                           activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim *
                        (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp, tf.nn.softplus(tf.linalg.diag_part(Lp))
                    )  # Cholesky factors must have positive diagonal

                    logging.info('L%d.shape ==========' % i)
                    logging.info(L.get_shape().as_list())

                    sigma[i] = L

                    # Correlated reparameterisation: reshape eps to
                    # (batch, zdim, H*W, 1), multiply by the Cholesky
                    # factor, then restore NHWC.
                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))

                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp,
                        [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])

                    z[i] = mu[i] + eps_tmp

                else:
                    # Diagonal covariance: per-pixel softplus std dev.

                    sigma[i] = layers.conv2D(z_input,
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus,
                                             kernel_size=(1, 1))
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            else:
                # Finer levels: project all coarser samples (all at the same
                # spatial resolution here) and condition on them.

                for j in reversed(range(0, i + 1)):
                    z_connect = layers.conv2D(z_mat[j + 1][i + 1],
                                              name='double_res_%d_to_%d' %
                                              ((i + 1), (j)),
                                              num_filters=2 * zdim_0,
                                              normalisation=norm,
                                              training=training)
                    z_mat[j][i + 1] = z_connect

                # Condition on either all coarser levels or only the one
                # immediately below.
                if full_latent_dependencies:
                    z_input = tf.concat([z_input] + z_mat[i][(i + 1):levels],
                                        axis=3,
                                        name='concat_%d' % i)
                else:
                    z_input = tf.concat([z_input, z_mat[i][(i + 1)]],
                                        axis=3,
                                        name='concat_%d' % i)

                mu[i] = layers.conv2D(z_input,
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)

                if full_cov_list[i] == True:
                    l = layers.dense_layer(z_input,
                                           'z%d_sigma' % i,
                                           hidden_units=zdim_0 *
                                           spatial_cov_dim *
                                           (spatial_cov_dim + 1) // 2,
                                           activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim *
                        (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp, tf.nn.softplus(tf.linalg.diag_part(Lp)))

                    sigma[i] = L

                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))

                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp,
                        [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])

                    z[i] = mu[i] + eps_tmp

                else:

                    # NOTE(review): unlike the corresponding branch above,
                    # this conv omits kernel_size=(1, 1) -- confirm whether
                    # the layer's default kernel size is intended here.
                    sigma[i] = layers.conv2D(z_input,
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus)
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            # Cache this level's sample at its own slot.
            z_mat[i][i] = z[i]

    return z, mu, sigma
Exemplo n.º 8
0
def proposed(z_list,
             training,
             image_size,
             n_classes,
             scope_reuse=False,
             norm=tfnorm.batch_norm,
             rank=10,
             diagonal=False,
             **kwargs):
    """Likelihood network producing a (low-rank) multivariate normal over
    flattened segmentation logits.

    A U-Net encoder/decoder over `x` (passed via kwargs) is followed by 1x1
    heads for the mean, log-diagonal, and low-rank covariance factor of the
    logit distribution.

    Args:
        z_list: See NOTE below -- not used by this network.
        training: Passed through to the normalisation layers.
        image_size: Tuple including a trailing channel dim; spatial part is
            combined with n_classes to form the flattened event shape.
        n_classes: Number of segmentation classes.
        scope_reuse: If True, reuse variables in the 'likelihood' scope.
        norm: Normalisation function for the conv layers.
        rank: Rank of the low-rank covariance factor.
        diagonal: If True, use a diagonal-covariance distribution instead.
        **kwargs: x (required input image), resolution_levels, n0.

    Returns:
        [[s], dist]: one sample `s` reshaped to image_size[:-1]+(n_classes,)
        and the distribution object itself.
    """
    # NOTE(review): z_list is unused; presumably kept so all likelihood
    # networks share the same call signature -- confirm with callers.
    x = kwargs.get('x')

    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)
    # Channel schedule per level; supports at most 7 resolution levels.
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    conv_unit = layers.conv2D
    deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm has its own offset, so conv biases are redundant.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []

        with tf.variable_scope('encoder'):

            # Downsampling path: average pool then three convs per level.
            for ii in range(resolution_levels):

                enc.append([])

                # In first layer set input to x rather than max pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))

                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_1' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_2' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_3' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        dec = []

        with tf.variable_scope('decoder'):

            # Upsampling path with skip connections from the encoder.
            for jj in range(resolution_levels - 1):

                ii = resolution_levels - jj - 1  # used to index the encoder again

                dec.append([])

                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]

                dec[jj].append(deconv_unit(next_inp))

                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]],
                                           axis=3))

                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_1' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_2' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_3' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        net = dec[-1][-1]

        # Three 1x1 recombination convs before the distribution heads.
        recomb = conv_unit(net,
                           'recomb_0',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_1',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_2',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)

        # Floor on the diagonal covariance to keep it strictly positive.
        epsilon = 1e-5

        mean = layers.conv2D(recomb,
                             'mean',
                             num_filters=n_classes,
                             kernel_size=(1, 1),
                             activation=tf.identity)
        log_cov_diag = layers.conv2D(recomb,
                                     'diag',
                                     num_filters=n_classes,
                                     kernel_size=(1, 1),
                                     activation=tf.identity)
        cov_factor = layers.conv2D(recomb,
                                   'factor',
                                   num_filters=n_classes * rank,
                                   kernel_size=(1, 1),
                                   activation=tf.identity)

        # Flatten spatial + class dims so the distribution covers one vector.
        shape = image_size[:-1] + (n_classes, )
        flat_size = np.prod(shape)
        mean = tf.reshape(mean, [-1, flat_size])
        cov_diag = tf.exp(tf.reshape(log_cov_diag, [-1, flat_size])) + epsilon
        cov_factor = tf.reshape(cov_factor, [-1, flat_size, rank])
        if diagonal:
            dist = DiagonalMultivariateNormal(loc=mean, cov_diag=cov_diag)
        else:
            dist = LowRankMultivariateNormal(loc=mean,
                                             cov_diag=cov_diag,
                                             cov_factor=cov_factor)

        s = dist.rsample((1, ))
        s = tf.reshape(s, (-1, ) + shape)
        return [[s], dist]
Exemplo n.º 9
0
def phiseg(z_list,
           training,
           image_size,
           n_classes,
           scope_reuse=False,
           norm=tfnorm.batch_norm,
           **kwargs):
    """PHiSeg likelihood network.

    This is a U-NET like arch with skips before and after latent space and
    a rather simple decoder. Each latent level in `z_list` is
    post-processed by two convs, upsampled to the finest latent resolution,
    merged top-down with the coarser level, and mapped by a 1x1 conv to a
    per-level segmentation logit map resized (nearest neighbour) to
    `image_size`.

    Returns:
        s: list of `latent_levels` logit tensors at the input image size.
    """

    n0 = kwargs.get('n0', 32)
    # Channel schedule per level; supports at most 7 levels.
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    def increase_resolution(x, times, num_filters, name):
        # Bilinearly upsample x by 2x `times` times, conv after each step.

        with tf.variable_scope(name):
            nett = x

            for i in range(times):
                nett = layers.bilinear_upsample2D(nett, 'ups_%d' % i, 2)
                nett = layers.conv2D(nett,
                                     'z%d_post' % i,
                                     num_filters=num_filters,
                                     normalisation=norm,
                                     training=training)

        return nett

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        resolution_levels = kwargs.get('resolution_levels', 7)
        latent_levels = kwargs.get('latent_levels', 5)
        lvl_diff = resolution_levels - latent_levels

        post_z = [None] * latent_levels
        post_c = [None] * latent_levels

        s = [None] * latent_levels

        # Generate post_z: per-level processing and upsampling of latents.
        for i in range(latent_levels):
            net = layers.conv2D(z_list[i],
                                'z%d_post_1' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_post_2' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = increase_resolution(net,
                                      resolution_levels - latent_levels,
                                      num_filters=num_channels[i],
                                      name='preups_%d' % i)

            post_z[i] = net

        # Upstream path: merge coarser levels into finer ones, top-down.
        post_c[latent_levels - 1] = post_z[latent_levels - 1]

        for i in reversed(range(latent_levels - 1)):
            ups_below = layers.bilinear_upsample2D(post_c[i + 1],
                                                   name='post_z%d_ups' %
                                                   (i + 1),
                                                   factor=2)
            ups_below = layers.conv2D(ups_below,
                                      'post_z%d_ups_c' % (i + 1),
                                      num_filters=num_channels[i],
                                      normalisation=norm,
                                      training=training)

            concat = tf.concat([post_z[i], ups_below],
                               axis=3,
                               name='concat_%d' % i)

            net = layers.conv2D(concat,
                                'post_c_%d_1' % i,
                                num_filters=num_channels[i + lvl_diff],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'post_c_%d_2' % i,
                                num_filters=num_channels[i + lvl_diff],
                                normalisation=norm,
                                training=training)

            post_c[i] = net

        # Outputs: one logit map per latent level, resized to image size.
        for i in range(latent_levels):
            s_in = layers.conv2D(post_c[i],
                                 'y_lvl%d' % i,
                                 num_filters=n_classes,
                                 kernel_size=(1, 1),
                                 activation=tf.identity)
            s[i] = tf.image.resize_images(
                s_in,
                image_size[0:2],
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        return s
Exemplo n.º 10
0
def prob_unet2D(z_list,
                training,
                image_size,
                n_classes,
                scope_reuse=False,
                norm=tfnorm.batch_norm,
                **kwargs):
    """Probabilistic U-Net likelihood network.

    Runs a U-Net over the input image (passed as ``kwargs['x']``),
    broadcasts the single global latent ``z_list[0]`` over the spatial
    grid, concatenates it with the final decoder features, and maps the
    result to logits through three 1x1 'recomb' convolutions.

    Args:
        z_list: list of latent tensors; only ``z_list[0]`` is used
            (assumed shape (batch, zdim) — TODO confirm against sampler).
        training: bool tensor/flag forwarded to the normalisation layers.
        image_size: spatial size used to tile the latent (H, W indexed).
        n_classes: number of output channels of the prediction layer.
        scope_reuse: reuse variables in the 'likelihood' scope.
        norm: normalisation layer constructor.
        **kwargs: ``x`` (input image, required), ``resolution_levels``
            (default 7), ``n0`` (base filter count, default 32).

    Returns:
        A single-element list containing the logits tensor.
    """

    x = kwargs.get('x')
    z = z_list[0]

    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    conv_unit = layers.conv2D
    deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    bs = tf.shape(x)[0]
    zdim = z.get_shape().as_list()[-1]

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm carries its own shift term, so the convs drop the bias.
        add_bias = norm != tfnorm.batch_norm

        # Encoder: only each level's final tensor is consumed downstream,
        # so keep just those instead of every intermediate.
        enc_out = []

        with tf.variable_scope('encoder'):

            for level in range(resolution_levels):

                # First level reads the image directly; others downsample.
                feat = x if level == 0 else layers.averagepool2D(enc_out[-1])

                for k in (1, 2, 3):
                    feat = conv_unit(feat,
                                     'conv_%d_%d' % (level, k),
                                     num_filters=num_channels[level],
                                     training=training,
                                     normalisation=norm,
                                     add_bias=add_bias)

                enc_out.append(feat)

        with tf.variable_scope('decoder'):

            feat = enc_out[-1]

            for step in range(resolution_levels - 1):

                # Encoder level whose width this decoder stage mirrors.
                enc_level = resolution_levels - step - 1

                feat = deconv_unit(feat)

                # Skip connection from the next-higher resolution.
                feat = layers.crop_and_concat([feat, enc_out[enc_level - 1]],
                                              axis=3)

                for k in (1, 2, 3):
                    feat = conv_unit(feat,
                                     'conv_%d_%d' % (step, k),
                                     num_filters=num_channels[enc_level],
                                     training=training,
                                     normalisation=norm,
                                     add_bias=add_bias)

        # Broadcast the global latent over the full spatial grid and fuse it
        # with the decoder output.
        z_spatial = tf.reshape(z, tf.stack((bs, 1, 1, zdim)))
        z_tiled = tf.tile(z_spatial, (1, image_size[0], image_size[1], 1))
        feat = tf.concat([feat, z_tiled], axis=-1)

        # Three 1x1 convolutions recombine image features and latent.
        for k in range(3):
            feat = conv_unit(feat,
                             'recomb_%d' % k,
                             num_filters=num_channels[0],
                             kernel_size=(1, 1),
                             training=training,
                             normalisation=norm,
                             add_bias=add_bias)

        return [
            layers.conv2D(feat,
                          'prediction',
                          num_filters=n_classes,
                          kernel_size=(1, 1),
                          activation=tf.identity)
        ]
Exemplo n.º 11
0
def prob_unet2D_arch(
        x,
        training,
        nlabels,
        n0=32,
        resolution_levels=7,
        norm=tfnorm.batch_norm,
        conv_unit=layers.conv2D,
        deconv_unit=lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2),
        scope_reuse=False,
        return_net=False):
    """U-Net architecture used as the Probabilistic U-Net likelihood body.

    Encoder: `resolution_levels` stages of (average-pool +) three
    convolutions; decoder: bilinear upsampling, skip concatenation with the
    matching encoder stage, and three more convolutions. Three 1x1 'recomb'
    convolutions precede the final 1x1 prediction layer.

    Args:
        x: input image tensor, NHWC.
        training: bool tensor/flag forwarded to the normalisation layers.
        nlabels: number of output channels of the prediction layer.
        n0: base filter count; channel widths grow as n0, 2n0, 4n0, 6n0, ...
        resolution_levels: number of encoder stages.
        norm: normalisation layer constructor.
        conv_unit: convolution layer constructor.
        deconv_unit: upsampling layer constructor.
        scope_reuse: reuse variables in the 'likelihood' scope.
        return_net: if True, additionally return the list of all
            intermediate encoder/decoder tensors (as in `unet2D`).

    Returns:
        The logits tensor `s`, or `(s, net)` when `return_net` is True.
    """

    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm carries its own shift term, so the convs drop the bias.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []

        with tf.variable_scope('encoder'):

            # enc[ii] collects every intermediate tensor of stage ii; only
            # enc[ii][-1] (the stage output) is consumed downstream.
            for ii in range(resolution_levels):

                enc.append([])

                # In first layer set input to x rather than pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))

                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_1' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_2' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_3' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        dec = []

        with tf.variable_scope('decoder'):

            for jj in range(resolution_levels - 1):

                ii = resolution_levels - jj - 1  # used to index the encoder again

                dec.append([])

                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]

                dec[jj].append(deconv_unit(next_inp))

                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]],
                                           axis=3))

                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_1' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_2' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_3' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        recomb = conv_unit(dec[-1][-1],
                           'recomb_0',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_1',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_2',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)

        s = layers.conv2D(recomb,
                          'prediction',
                          num_filters=nlabels,
                          kernel_size=(1, 1),
                          activation=tf.identity)

        # Fix: `return_net` was previously accepted but ignored; honour it
        # the same way `unet2D` does.
        if return_net:
            dec[-1].append(s)
            net = enc + dec
            return s, net

        return s
Exemplo n.º 12
0
def phiseg(x,
           s_oh,
           zdim_0,
           training,
           scope_reuse=False,
           norm=tfnorm.batch_norm,
           **kwargs):
    """PHiSeg-style hierarchical posterior network q(z | x, s).

    Encodes the image `x` concatenated with the (zero-centred) one-hot
    segmentation `s_oh` through `resolution_levels` conv/downsample stages,
    then derives `latent_levels` Gaussian latents top-down: each level's
    (mu, sigma) is predicted from the matching encoder feature map plus the
    upsampled sample from the coarser level, and z is drawn with the
    reparameterisation trick.

    Args:
        x: input image tensor, NHWC.
        s_oh: one-hot segmentation tensor, same spatial size as x.
        zdim_0: number of channels of each latent z.
        training: bool tensor/flag forwarded to the normalisation layers.
        scope_reuse: reuse variables in the 'posterior' scope.
        norm: normalisation layer constructor.
        **kwargs: n0 (base filters, default 32), latent_levels (default 5),
            resolution_levels (default 7).

    Returns:
        (z, mu, sigma): lists of length `latent_levels`, index 0 being the
        highest-resolution latent.
    """

    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('posterior') as scope:

        if scope_reuse:
            scope.reuse_variables()

        latent_levels = kwargs.get('latent_levels', 5)
        resolution_levels = kwargs.get('resolution_levels', 7)

        pre_z = [None] * resolution_levels

        mu = [None] * latent_levels
        sigma = [None] * latent_levels
        z = [None] * latent_levels

        # z_ups_mat[j][i]: z of latent level i upsampled to level j's
        # resolution (encoding [original resolution][upsampled to]).
        z_ups_mat = []
        for i in range(latent_levels):
            z_ups_mat.append([None] * latent_levels)

        # Encoder: generate pre_z feature maps at every resolution level.
        for i in range(resolution_levels):

            if i == 0:
                # Centre the one-hot labels around zero before concatenating.
                net = tf.concat([x, s_oh - 0.5], axis=-1)
            else:
                net = layers.averagepool2D(pre_z[i - 1])

            net = layers.conv2D(net,
                                'z%d_pre_1' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_2' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_3' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)

            pre_z[i] = net

        # Generate z's top-down (coarsest latent first).
        for i in reversed(range(latent_levels)):

            if i == latent_levels - 1:

                # Top level: predict (mu, sigma) directly from the encoder.
                # NOTE(review): this mu conv uses the default kernel size,
                # whereas the matching sigma conv and all lower-level mu
                # convs use (1, 1) — confirm the asymmetry is intentional
                # before changing it (it affects checkpoint layout).
                mu[i] = layers.conv2D(pre_z[i + resolution_levels -
                                            latent_levels],
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)

                sigma[i] = layers.conv2D(pre_z[i + resolution_levels -
                                               latent_levels],
                                         'z%d_sigma' % i,
                                         num_filters=zdim_0,
                                         activation=tf.nn.softplus,
                                         kernel_size=(1, 1))
                z[i] = mu[i] + sigma[i] * tf.random_normal(
                    tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            else:

                # Upsample the coarser level's sample all the way down to
                # this level, two convs per upsampling step.
                for j in reversed(range(0, i + 1)):

                    z_below_ups = layers.bilinear_upsample2D(
                        z_ups_mat[j + 1][i + 1], factor=2, name='ups')
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_1' %
                                                ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_2' %
                                                ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)

                    z_ups_mat[j][i + 1] = z_below_ups

                # Condition on both the encoder features and the upsampled
                # coarser latent.
                z_input = tf.concat([
                    pre_z[i + resolution_levels - latent_levels],
                    z_ups_mat[i][i + 1]
                ],
                                    axis=3,
                                    name='concat_%d' % i)

                z_input = layers.conv2D(z_input,
                                        'z%d_input_1' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm,
                                        training=training)
                z_input = layers.conv2D(z_input,
                                        'z%d_input_2' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm,
                                        training=training)

                mu[i] = layers.conv2D(z_input,
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity,
                                      kernel_size=(1, 1))

                sigma[i] = layers.conv2D(z_input,
                                         'z%d_sigma' % i,
                                         num_filters=zdim_0,
                                         activation=tf.nn.softplus,
                                         kernel_size=(1, 1))
                z[i] = mu[i] + sigma[i] * tf.random_normal(
                    tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            z_ups_mat[i][i] = z[i]

    return z, mu, sigma
Exemplo n.º 13
0
def segvae_const_latent(z_list,
                        training,
                        image_size,
                        n_classes,
                        scope_reuse=False,
                        norm=tfnorm.batch_norm,
                        **kwargs):
    """Simple multi-resolution decoder for a list of latents.

    Each latent in `z_list` is projected with a convolution, then combined
    top-down: the coarser level is upsampled once and concatenated with the
    current level upsampled to the same resolution. A per-level 1x1
    convolution produces logits, which are resized (nearest neighbour) to
    the full image size.

    Args:
        z_list: list of `resolution_levels` latent tensors, finest first.
        training: bool tensor/flag forwarded to the normalisation layers.
        image_size: (H, W, C) of the target image; only H, W are used here.
        n_classes: number of output channels per logits map.
        scope_reuse: reuse variables in the 'likelihood' scope.
        norm: normalisation layer constructor.
        **kwargs: n0 (base filters, default 32), resolution_levels
            (default 3).

    Returns:
        List of `resolution_levels` logits tensors, all at full image size.
    """

    n0 = kwargs.get('n0', 32)
    resolution_levels = kwargs.get('resolution_levels', 3)

    def increase_resolution(x, times, name):
        # Upsample `x` by a factor of 2 `times` times, with one
        # channel-preserving convolution after each upsampling.

        with tf.variable_scope(name):
            nett = x

            for i in range(times):

                nett = layers.bilinear_upsample2D(nett, 'ups_%d' % i, 2)
                nC = nett.get_shape().as_list()[3]
                nett = layers.conv2D(nett,
                                     'z%d_post' % i,
                                     num_filters=nC,
                                     normalisation=norm,
                                     training=training)

        return nett

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        z_list_c = []
        pre_out = [None] * resolution_levels
        s = [None] * resolution_levels

        # Project each latent before merging.
        for i in range(resolution_levels):

            z_list_c.append(
                layers.conv2D(z_list[i],
                              'z%d_post' % i,
                              num_filters=n0 * (i // 2 + 1),
                              normalisation=norm,
                              training=training))

        pre_out[resolution_levels - 1] = z_list_c[resolution_levels - 1]

        # Top-down path: merge each level with the upsampled coarser output.
        for i in reversed(range(resolution_levels - 1)):

            top = increase_resolution(z_list_c[i], resolution_levels - i - 1,
                                      'upsample_top_%d' % i)
            bottom = increase_resolution(pre_out[i + 1], 1,
                                         'upsample_bottom_%d' % i)
            net = tf.concat([top, bottom], axis=3)
            pre_out[i] = layers.conv2D(net,
                                       'preout_%d' % i,
                                       num_filters=n0 * (i // 2 + 1),
                                       normalisation=norm,
                                       training=training)

        # Per-level logits, resized to the full image resolution.
        for i in range(resolution_levels):
            s_in = layers.conv2D(pre_out[i],
                                 'y_lvl%d' % i,
                                 num_filters=n_classes,
                                 kernel_size=(1, 1),
                                 activation=tf.identity)
            s[i] = tf.image.resize_images(
                s_in,
                image_size[0:2],
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        return s
Exemplo n.º 14
0
def unet_T_L_noconcat(z_list,
                      training,
                      image_size,
                      n_classes,
                      scope_reuse=False,
                      norm=tfnorm.batch_norm,
                      **kwargs):
    """Ablation decoder: per-level latent heads WITHOUT cross-level concat.

    Unlike its siblings, the upstream path does not upsample and merge the
    coarser level into the finer one — each latent level is decoded
    independently (hence 'noconcat'). Each latent is convolved, upsampled
    to full feature resolution, refined with two more convolutions, and
    mapped to per-level logits resized to the image size.

    Args:
        z_list: list of at least `latent_levels` latent tensors.
        training: bool tensor/flag forwarded to the normalisation layers.
        image_size: (H, W, ...) of the target image; only H, W are used.
        n_classes: number of output channels per logits map.
        scope_reuse: reuse variables in the 'likelihood' scope.
        norm: normalisation layer constructor.
        **kwargs: n0 (default 32), max_channel_power (default 4),
            resolution_levels (default 6), latent_levels (default 3).

    Returns:
        List of `latent_levels` logits tensors, all at full image size.
    """

    n0 = kwargs.get('n0', 32)
    max_channel_power = kwargs.get('max_channel_power', 4)
    # Channel count is capped at n0 * 2**max_channel_power while upsampling.
    max_channels = n0 * 2**max_channel_power

    def increase_resolution(x, times, name):
        # Upsample `x` by 2 `times` times, doubling channels (up to the
        # cap) with one convolution after each upsampling.

        with tf.variable_scope(name):
            nett = x

            for i in range(times):
                nett = layers.bilinear_upsample2D(nett, 'ups_%d' % i, 2)
                nC = nett.get_shape().as_list()[3]
                nett = layers.conv2D(nett,
                                     'z%d_post' % i,
                                     num_filters=min(nC * 2, max_channels),
                                     normalisation=norm,
                                     training=training)

        return nett

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        resolution_levels = kwargs.get('resolution_levels', 6)
        latent_levels = kwargs.get('latent_levels', 3)

        post_z = [None] * latent_levels
        post_c = [None] * latent_levels

        s = [None] * latent_levels

        # Generate post_z: decode each latent to full feature resolution.
        for i in range(latent_levels):
            net = layers.conv2D(z_list[i],
                                'z%d_post_1' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_post_2' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)
            net = increase_resolution(net,
                                      resolution_levels - latent_levels,
                                      name='preups_%d' % i)

            post_z[i] = net

        # Upstream path — deliberately no concatenation with coarser levels.
        post_c[latent_levels - 1] = post_z[latent_levels - 1]

        for i in reversed(range(latent_levels - 1)):

            concat = post_z[i]

            net = layers.conv2D(concat,
                                'post_c_%d_1' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'post_c_%d_2' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)

            post_c[i] = net

        # Outputs: per-level logits resized to the image resolution.
        for i in range(latent_levels):
            s_in = layers.conv2D(post_c[i],
                                 'y_lvl%d' % i,
                                 num_filters=n_classes,
                                 kernel_size=(1, 1),
                                 activation=tf.identity)
            s[i] = tf.image.resize_images(
                s_in,
                image_size[0:2],
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        return s
Exemplo n.º 15
0
def unet2D(x,
           training,
           nlabels,
           n0=64,
           resolution_levels=5,
           norm=tfnorm.batch_norm,
           conv_unit=layers.conv2D,
           deconv_unit=layers.transposed_conv2D,
           simplified_dec=False,
           scope_reuse=False,
           return_net=False):
    """Standard 2D U-Net segmentation network.

    Encoder: `resolution_levels` stages of (maxpool +) two convolutions,
    doubling the filter count per stage starting from `n0`. Decoder:
    transposed-conv upsampling, skip concatenation with the matching
    encoder stage, and two more convolutions per stage. A final 1x1
    convolution maps to `nlabels` output channels.

    Args:
        x: input image tensor, NHWC.
        training: bool tensor/flag forwarded to the normalisation layers.
        nlabels: number of output classes.
        n0: filter count at the first resolution level.
        resolution_levels: number of encoder stages.
        norm: normalisation layer constructor.
        conv_unit: convolution layer constructor.
        deconv_unit: upsampling layer constructor.
        simplified_dec: if True, the transposed convs output only `nlabels`
            channels instead of mirroring the encoder widths.
        scope_reuse: reuse variables in the 'segmenter' scope.
        return_net: if True, additionally return the list of all
            intermediate encoder/decoder tensors.

    Returns:
        The output tensor, or (output, net) when `return_net` is True.
    """

    with tf.variable_scope('segmenter') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm carries its own shift term, so the convs drop the bias.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []

        with tf.variable_scope('encoder'):

            # enc[ii] collects every intermediate tensor of stage ii; only
            # enc[ii][-1] (the stage output) is consumed downstream.
            for ii in range(resolution_levels):

                enc.append([])

                # In first layer set input to x rather than max pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.maxpool2D(enc[ii - 1][-1]))

                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_1' % ii,
                              num_filters=n0 * (2**ii),
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_2' % ii,
                              num_filters=n0 * (2**ii),
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        dec = []

        with tf.variable_scope('decoder'):

            for jj in range(resolution_levels - 1):

                ii = resolution_levels - jj - 1  # used to index the encoder again

                dec.append([])

                # First decoder stage reads the encoder bottleneck; later
                # ones read the previous decoder stage.
                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]

                if simplified_dec:
                    num_transconv_filters = nlabels
                else:
                    num_transconv_filters = n0 * (2**(ii - 1))

                dec[jj].append(
                    deconv_unit(next_inp,
                                name='upconv_%d' % jj,
                                num_filters=num_transconv_filters,
                                training=training,
                                normalisation=norm,
                                add_bias=add_bias))

                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]],
                                           axis=3))

                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_1' % jj,
                              num_filters=n0 * (2**(ii - 1)),
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias,
                              projection=True)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_2' % jj,
                              num_filters=n0 * (2**(ii - 1)),
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        # NOTE(review): the prediction layer applies `norm` to the output
        # despite activation=tf.identity, i.e. the logits are normalised —
        # confirm this is intentional (siblings above omit it here).
        output = layers.conv2D(dec[-1][-1],
                               'prediction',
                               num_filters=nlabels,
                               kernel_size=(1, 1),
                               activation=tf.identity,
                               training=training,
                               normalisation=norm,
                               add_bias=add_bias)

        dec[-1].append(output)

        if return_net:
            net = enc + dec
            return output, net

        return output