def betaVAE_bn(x, s_oh, zdim_0, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """Posterior (encoder) network of a beta-VAE with a single global latent.

    Downsamples the concatenation of image and (shifted) one-hot segmentation
    with strided convs, then predicts a diagonal Gaussian over a `zdim_0`-dim
    latent and draws one reparameterized sample.

    Args:
        x: input image tensor; assumes NHWC layout (spatial dims at [1:3]).
        s_oh: one-hot segmentation tensor, same spatial size as `x`.
        zdim_0: dimensionality of the latent code.
        training: bool (or bool tensor) passed to normalisation layers.
        scope_reuse: if True, reuse variables in the 'posterior' scope.
        norm: normalisation function used by the conv layers.
        **kwargs: `resolution_levels` (default 5), `n0` base filters (default 32).

    Returns:
        (z, mu_z, sigma_z): single-element lists (kept as lists for interface
        compatibility with the hierarchical posteriors in this file).
    """
    resolution_levels = kwargs.get('resolution_levels', 5)
    image_size = x.get_shape().as_list()[1:3]
    # Kernel that exactly covers the remaining spatial extent after all
    # stride-2 downsampling steps, so the VALID conv below yields 1x1.
    final_kernel_size = [s // (2**(resolution_levels - 1)) for s in image_size]

    # POSTERIOR ####################
    with tf.variable_scope('posterior') as scope:

        if scope_reuse:
            scope.reuse_variables()

        n0 = kwargs.get('n0', 32)

        mu_z = []
        sigma_z = []
        z = []

        # Generate pre_z's
        # Shift one-hot labels to be zero-centred before concatenation.
        net = tf.concat([x, s_oh - 0.5], axis=-1)

        for ii in range(resolution_levels - 1):
            net = layers.conv2D(net,
                                'q_z_%d' % ii,
                                num_filters=n0 * (ii // 2 + 1),
                                kernel_size=(4, 4),
                                strides=(2, 2),
                                normalisation=norm,
                                training=training)

        # NOTE(review): this layer is named 'q_z_%d' % resolution_levels, skipping
        # index resolution_levels-1 — harmless, but renaming would break checkpoints.
        net = layers.conv2D(net,
                            'q_z_%d' % resolution_levels,
                            num_filters=n0 * 8,
                            kernel_size=final_kernel_size,
                            strides=(1, 1),
                            padding='VALID',
                            normalisation=norm,
                            training=training)

        mu_z.append(
            layers.dense_layer(net,
                               'z_mu',
                               hidden_units=zdim_0,
                               activation=tf.identity))
        # softplus keeps the predicted std strictly positive
        sigma_z.append(
            layers.dense_layer(net,
                               'z_sigma',
                               hidden_units=zdim_0,
                               activation=tf.nn.softplus))

        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        z.append(mu_z[0] + sigma_z[0] * tf.random_normal(
            tf.shape(mu_z[0]), 0, 1, dtype=tf.float32))

    return z, mu_z, sigma_z
def id_res_unet2D(x, training, nlabels, n0=64, resolution_levels=5, norm=tfnorm.batch_norm, scope_reuse=False, return_net=False):
    """U-Net built from identity-residual units, with a plain conv input stem.

    Applies one standard convolution ('input_layer') to lift `x` to `n0`
    channels, then delegates to `unet2D` using
    `layers.identity_residual_unit2D` as the convolutional building block.
    All arguments are forwarded to `unet2D` unchanged.
    """
    # Bias is redundant when batch norm follows the convolution.
    use_bias = norm != tfnorm.batch_norm

    stem = layers.conv2D(x,
                         training=training,
                         num_filters=n0,
                         name='input_layer',
                         normalisation=norm,
                         add_bias=use_bias)

    return unet2D(stem,
                  training,
                  nlabels,
                  n0=n0,
                  resolution_levels=resolution_levels,
                  norm=norm,
                  conv_unit=layers.identity_residual_unit2D,
                  scope_reuse=scope_reuse,
                  return_net=return_net)
def reduce_resolution(x, times, name, norm=tfnorm.batch_norm, training=False, max_channels=None):
    """Halve the spatial resolution of `x` a total of `times` times.

    Each step applies a space-to-depth style `reshape_pool2D_layer` (which
    multiplies the channel count) followed by a conv that reduces channels
    back to a quarter of the pooled count, optionally capped.

    Fix: the original body referenced `max_channels`, `norm` and `training`
    as free (undefined) module-level names — a NameError at call time; this
    version takes them as keyword parameters (backward-compatible for the
    original three positional arguments).

    Args:
        x: input NHWC tensor.
        times: number of halving steps.
        name: variable scope name for the created layers.
        norm: normalisation function for the conv layers.
        training: bool (or bool tensor) for normalisation; TODO confirm the
            intended default with callers.
        max_channels: optional upper bound on the per-step filter count;
            None means uncapped.

    Returns:
        Tensor downsampled by 2**times in each spatial dimension.
    """
    with tf.variable_scope(name):
        nett = x
        for ii in range(times):
            nett = layers.reshape_pool2D_layer(nett)
            nC = nett.get_shape().as_list()[3]
            # reshape-pooling quadrupled the channels; conv back down to nC // 4.
            num_filters = nC // 4 if max_channels is None else min(nC // 4, max_channels)
            nett = layers.conv2D(nett,
                                 'down_%d' % ii,
                                 num_filters=num_filters,
                                 normalisation=norm,
                                 training=training)
        return nett
def increase_resolution(x, times, num_filters, name, norm=tfnorm.batch_norm, training=False):
    """Double the spatial resolution of `x` a total of `times` times.

    Each step bilinearly upsamples by 2 and applies one conv with
    `num_filters` output channels.

    Fix: the original body referenced `norm` and `training` as free
    (undefined) module-level names — a NameError at call time; this version
    takes them as keyword parameters (backward-compatible for the original
    four positional arguments).

    Args:
        x: input NHWC tensor.
        times: number of doubling steps.
        num_filters: output channels of each post-upsample conv.
        name: variable scope name for the created layers.
        norm: normalisation function for the conv layers.
        training: bool (or bool tensor) for normalisation; TODO confirm the
            intended default with callers.

    Returns:
        Tensor upsampled by 2**times in each spatial dimension.
    """
    with tf.variable_scope(name):
        nett = x
        for i in range(times):
            nett = layers.bilinear_upsample2D(nett, 'ups_%d' % i, 2)
            nett = layers.conv2D(nett,
                                 'z%d_post' % i,
                                 num_filters=num_filters,
                                 normalisation=norm,
                                 training=training)
        return nett
def increase_resolution(x, times, name, norm=tfnorm.batch_norm, training=False, max_channels=None):
    """Double the spatial resolution of `x` a total of `times` times,
    doubling the channel count (optionally capped) at each step.

    NOTE(review): this redefinition shadows the `increase_resolution`
    defined just above (which takes an explicit `num_filters`); only this
    version is visible at module level after import — consider renaming one.

    Fix: the original body referenced `max_channels`, `norm` and `training`
    as free (undefined) module-level names — a NameError at call time; this
    version takes them as keyword parameters (backward-compatible for the
    original three positional arguments).

    Args:
        x: input NHWC tensor.
        times: number of doubling steps.
        name: variable scope name for the created layers.
        norm: normalisation function for the conv layers.
        training: bool (or bool tensor) for normalisation; TODO confirm the
            intended default with callers.
        max_channels: optional upper bound on the per-step filter count;
            None means uncapped.

    Returns:
        Tensor upsampled by 2**times in each spatial dimension.
    """
    with tf.variable_scope(name):
        nett = x
        for i in range(times):
            nett = layers.bilinear_upsample2D(nett, 'ups_%d' % i, 2)
            nC = nett.get_shape().as_list()[3]
            # Double the channels each step, capped at max_channels if given.
            num_filters = nC * 2 if max_channels is None else min(nC * 2, max_channels)
            nett = layers.conv2D(nett,
                                 'z%d_post' % i,
                                 num_filters=num_filters,
                                 normalisation=norm,
                                 training=training)
        return nett
def unet_T_L(x, s_oh, zdim_0, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """Hierarchical posterior network with T resolution levels and L latent levels.

    Builds a downsampling path over `resolution_levels` scales, then — from
    coarse to fine — predicts per-level Gaussian latents. Each level can use
    either a diagonal covariance (per-pixel softplus sigma) or, when
    `full_cov_list[i]` is True, a full covariance over the level's spatial
    grid parameterized by a Cholesky factor.

    Args:
        x: input image tensor; assumes NHWC layout (spatial dims at [1:3]).
        s_oh: one-hot segmentation, same spatial size as `x`.
        zdim_0: channels of each latent level.
        training: bool (or bool tensor) for normalisation layers.
        scope_reuse: if True, reuse variables in the 'posterior' scope.
        norm: normalisation function for the conv layers.
        **kwargs: `full_cov_list` (per-level bools; NOTE(review): default is
            None and it is indexed unconditionally below, so it must be
            supplied), `n0`, `max_channel_power`, `latent_levels`,
            `resolution_levels`, `full_latent_dependencies`.

    Returns:
        (z, mu, sigma): lists of length `latent_levels`, finest level first.
        `sigma[i]` is a per-pixel std map for diagonal levels, or a Cholesky
        factor of shape [bs, zdim_0, S, S] (S = spatial pixels) otherwise.
    """
    # POSTERIOR ####################
    with tf.variable_scope('posterior') as scope:
        if scope_reuse:
            scope.reuse_variables()

        full_cov_list = kwargs.get('full_cov_list', None)
        n0 = kwargs.get('n0', 32)
        max_channel_power = kwargs.get('max_channel_power', 4)
        max_channels = n0 * 2**max_channel_power  # NOTE(review): unused here
        latent_levels = kwargs.get('latent_levels', 4)
        resolution_levels = kwargs.get('resolution_levels', 6)
        spatial_xdim = x.get_shape().as_list()[1:3]
        full_latent_dependencies = kwargs.get('full_latent_dependencies', False)

        pre_z = [None] * resolution_levels
        mu = [None] * latent_levels
        sigma = [None] * latent_levels
        z = [None] * latent_levels

        z_ups_mat = []
        for i in range(latent_levels):
            z_ups_mat.append(
                [None] * latent_levels)  # encoding [original resolution][upsampled to]

        # Generate pre_z's
        for i in range(resolution_levels):
            if i == 0:
                # Shift one-hot labels to be zero-centred before concatenation.
                net = tf.concat([x, s_oh - 0.5], axis=-1)
            else:
                net = layers.reshape_pool2D_layer(pre_z[i - 1])
            net = layers.conv2D(net,
                                'z%d_pre_1' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_2' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)
            pre_z[i] = net

        # Generate z's, from the coarsest latent level down to the finest.
        for i in reversed(range(latent_levels)):

            # Spatial size of level i: latents live on the coarsest
            # (resolution_levels - latent_levels) .. (resolution_levels - 1) scales.
            spatial_zdim = [
                d // 2**(i + resolution_levels - latent_levels)
                for d in spatial_xdim
            ]
            spatial_cov_dim = spatial_zdim[0] * spatial_zdim[1]

            if i == latent_levels - 1:
                # Top (coarsest) level: predicted directly from the encoder feature.
                mu[i] = layers.conv2D(pre_z[i + resolution_levels - latent_levels],
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)

                if full_cov_list[i] == True:
                    # Full covariance: predict the zdim_0 * S*(S+1)/2 entries of a
                    # lower-triangular Cholesky factor per latent channel.
                    l = layers.dense_layer(
                        pre_z[i + resolution_levels - latent_levels],
                        'z%d_sigma' % i,
                        hidden_units=zdim_0 * spatial_cov_dim * (spatial_cov_dim + 1) // 2,
                        activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim * (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp,
                        tf.nn.softplus(tf.linalg.diag_part(Lp))
                    )  # Cholesky factors must have positive diagonal
                    sigma[i] = L

                    # Correlated sample: z = mu + L @ eps, eps ~ N(0, I),
                    # with eps flattened to [bs, zdim_0, S, 1] for the matmul.
                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))
                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp, [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])
                    z[i] = mu[i] + eps_tmp
                else:
                    # Diagonal covariance: per-pixel softplus std.
                    sigma[i] = layers.conv2D(pre_z[i + resolution_levels - latent_levels],
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus,
                                             kernel_size=(1, 1))
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)
            else:
                # Upsample the already-sampled coarser latents towards level i,
                # caching each intermediate resolution in z_ups_mat.
                for j in reversed(range(0, i + 1)):
                    z_below_ups = layers.nearest_neighbour_upsample2D(
                        z_ups_mat[j + 1][i + 1], factor=2)
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_1' % ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_2' % ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_ups_mat[j][i + 1] = z_below_ups

                if full_latent_dependencies:
                    # Condition on ALL coarser latents, not just the next one.
                    z_input = tf.concat(
                        [pre_z[i + resolution_levels - latent_levels]] +
                        z_ups_mat[i][(i + 1):latent_levels],
                        axis=3,
                        name='concat_%d' % i)
                else:
                    z_input = tf.concat([
                        pre_z[i + resolution_levels - latent_levels],
                        z_ups_mat[i][i + 1]
                    ],
                                        axis=3,
                                        name='concat_%d' % i)

                z_input = layers.conv2D(z_input,
                                        'z%d_input_1' % i,
                                        num_filters=n0 * (i // 2 + 1),
                                        normalisation=norm,
                                        training=training)
                z_input = layers.conv2D(z_input,
                                        'z%d_input_2' % i,
                                        num_filters=n0 * (i // 2 + 1),
                                        normalisation=norm,
                                        training=training)

                mu[i] = layers.conv2D(z_input,
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity,
                                      kernel_size=(1, 1))

                if full_cov_list[i] == True:
                    # Same full-covariance parameterization/sampling as the top level.
                    l = layers.dense_layer(
                        z_input,
                        'z%d_sigma' % i,
                        hidden_units=zdim_0 * spatial_cov_dim * (spatial_cov_dim + 1) // 2,
                        activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim * (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp, tf.nn.softplus(tf.linalg.diag_part(Lp)))
                    sigma[i] = L

                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))
                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp, [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])
                    z[i] = mu[i] + eps_tmp
                else:
                    sigma[i] = layers.conv2D(z_input,
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus,
                                             kernel_size=(1, 1))
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            # Seed the upsampling cache with this level's sample so finer
            # levels can condition on it.
            z_ups_mat[i][i] = z[i]

    return z, mu, sigma
def segvae_const_latent(x, s_oh, zdim_0, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """Posterior network where every latent level shares one (coarsest) resolution.

    Encoder features at each scale are first reduced to the common latent
    resolution (via `reduce_resolution`), then per-level Gaussians are
    predicted coarse-to-fine, each conditioned on the coarser samples.
    Levels can use a full spatial covariance when `full_cov_list[i]` is True.

    NOTE(review): a second, different `segvae_const_latent` (a likelihood
    network) is defined later in this file and shadows this one at import
    time — consider renaming.

    Args:
        x: input image tensor; assumes NHWC layout (spatial dims at [1:3]).
        s_oh: one-hot segmentation, same spatial size as `x`.
        zdim_0: channels of each latent level.
        training: bool (or bool tensor) for normalisation layers.
        scope_reuse: if True, reuse variables in the 'posterior' scope.
        norm: normalisation function for the conv layers.
        **kwargs: `n0`, `max_channel_power`, `full_cov_list` (NOTE(review):
            default None but indexed unconditionally — must be supplied),
            `resolution_levels`, `full_latent_dependencies`.

    Returns:
        (z, mu, sigma): lists of length `resolution_levels`, finest level first.
    """
    n0 = kwargs.get('n0', 32)
    max_channel_power = kwargs.get('max_channel_power', 4)
    max_channels = n0 * 2**max_channel_power
    full_cov_list = kwargs.get('full_cov_list', None)
    resolution_levels = kwargs.get('resolution_levels', 5)

    def reduce_resolution(x, times, name):
        # Bring a feature map down to the common latent resolution:
        # `times` reshape-pool steps, each followed by a channel-reducing conv.
        with tf.variable_scope(name):
            nett = x
            for ii in range(times):
                nett = layers.reshape_pool2D_layer(nett)
                nC = nett.get_shape().as_list()[3]
                nett = layers.conv2D(nett,
                                     'down_%d' % ii,
                                     num_filters=min(nC // 4, max_channels),
                                     normalisation=norm,
                                     training=training)
            return nett

    with tf.variable_scope('posterior') as scope:

        spatial_xdim = x.get_shape().as_list()[1:3]
        # All latents live at the coarsest scale ("const latent").
        spatial_zdim = [d // 2**(resolution_levels - 1) for d in spatial_xdim]
        spatial_cov_dim = spatial_zdim[0] * spatial_zdim[1]

        if scope_reuse:
            scope.reuse_variables()

        n0 = kwargs.get('n0', 32)
        levels = resolution_levels
        full_latent_dependencies = kwargs.get('full_latent_dependencies', False)

        pre_z = [None] * levels
        mu = [None] * levels
        sigma = [None] * levels
        z = [None] * levels

        z_mat = []
        for i in range(levels):
            z_mat.append(
                [None] * levels)  # encoding [original resolution][upsampled to]

        # Generate pre_z's
        for i in range(levels):
            if i == 0:
                # Shift one-hot labels to be zero-centred before concatenation.
                net = tf.concat([x, s_oh - 0.5], axis=-1)
            else:
                net = layers.maxpool2D(pre_z[i - 1])
            net = layers.conv2D(net,
                                'z%d_pre_1' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)
            pre_z[i] = net

        # Generate z's, coarsest level first.
        for i in reversed(range(levels)):

            # Reduce the level-i encoder feature to the shared latent resolution.
            z_input = reduce_resolution(pre_z[i],
                                        levels - i - 1,
                                        name='reduction_%d' % i)

            logging.info('z_input.shape')
            logging.info(z_input.get_shape().as_list())

            if i == levels - 1:
                mu[i] = layers.conv2D(z_input,
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)

                if full_cov_list[i] == True:
                    # Full covariance: predict a lower-triangular Cholesky factor
                    # per latent channel over the S latent pixels.
                    l = layers.dense_layer(
                        z_input,
                        'z%d_sigma' % i,
                        hidden_units=zdim_0 * spatial_cov_dim * (spatial_cov_dim + 1) // 2,
                        activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim * (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp,
                        tf.nn.softplus(tf.linalg.diag_part(Lp))
                    )  # Cholesky factors must have positive diagonal

                    logging.info('L%d.shape ==========' % i)
                    logging.info(L.get_shape().as_list())

                    sigma[i] = L

                    # Correlated sample: z = mu + L @ eps, eps ~ N(0, I).
                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))
                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp, [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])
                    z[i] = mu[i] + eps_tmp
                else:
                    sigma[i] = layers.conv2D(z_input,
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus,
                                             kernel_size=(1, 1))
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)
            else:
                # Propagate coarser samples towards level i (same resolution,
                # channel-doubling convs), caching intermediates in z_mat.
                for j in reversed(range(0, i + 1)):
                    z_connect = layers.conv2D(z_mat[j + 1][i + 1],
                                              name='double_res_%d_to_%d' % ((i + 1), (j)),
                                              num_filters=2 * zdim_0,
                                              normalisation=norm,
                                              training=training)
                    z_mat[j][i + 1] = z_connect

                if full_latent_dependencies:
                    # Condition on ALL coarser latents.
                    z_input = tf.concat([z_input] + z_mat[i][(i + 1):levels],
                                        axis=3,
                                        name='concat_%d' % i)
                else:
                    z_input = tf.concat([z_input, z_mat[i][(i + 1)]],
                                        axis=3,
                                        name='concat_%d' % i)

                mu[i] = layers.conv2D(z_input,
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)

                if full_cov_list[i] == True:
                    # Same full-covariance parameterization/sampling as above.
                    l = layers.dense_layer(
                        z_input,
                        'z%d_sigma' % i,
                        hidden_units=zdim_0 * spatial_cov_dim * (spatial_cov_dim + 1) // 2,
                        activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim * (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp, tf.nn.softplus(tf.linalg.diag_part(Lp)))
                    sigma[i] = L

                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))
                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp, [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])
                    z[i] = mu[i] + eps_tmp
                else:
                    sigma[i] = layers.conv2D(z_input,
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus)
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            # Seed the cache with this level's sample for finer levels.
            z_mat[i][i] = z[i]

    return z, mu, sigma
def proposed(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """U-Net likelihood that outputs a (low-rank or diagonal) multivariate
    normal distribution over the flattened logit map, plus one sample.

    The U-Net consumes `x` (taken from kwargs); `z_list` is accepted for
    interface compatibility with the other likelihood networks in this file
    but is not used here.

    Args:
        z_list: unused (kept for a uniform likelihood-network interface).
        training: bool (or bool tensor) for normalisation layers.
        image_size: (H, W, C) of the input; logits are flattened over H*W*n_classes.
        n_classes: number of output classes.
        scope_reuse: if True, reuse variables in the 'likelihood' scope.
        norm: normalisation function for the conv layers.
        rank: rank of the low-rank covariance factor.
        diagonal: if True, use a diagonal covariance (ignores `cov_factor`).
        **kwargs: `x` input image (required), `resolution_levels`, `n0`.

    Returns:
        [[s], dist]: one sample `s` reshaped to [-1, H, W, n_classes] and the
        distribution object itself.
    """
    rank = kwargs.get('rank', 10) if False else rank  # noqa  # (see signature)
    x = kwargs.get('x')

    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)
    # Channel schedule per resolution level (capped at 6*n0 from level 3 on).
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    conv_unit = layers.conv2D
    deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Bias is redundant when batch norm follows the convolution.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []
        with tf.variable_scope('encoder'):
            for ii in range(resolution_levels):

                enc.append([])

                # In first layer set input to x rather than max pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))

                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_1' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_2' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_3' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        dec = []
        with tf.variable_scope('decoder'):
            for jj in range(resolution_levels - 1):

                ii = resolution_levels - jj - 1  # used to index the encoder again

                dec.append([])

                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]

                dec[jj].append(deconv_unit(next_inp))

                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]],
                                           axis=3))

                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_1' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_2' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_3' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        net = dec[-1][-1]

        # 1x1 recombination convs before the distribution heads.
        recomb = conv_unit(net,
                           'recomb_0',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_1',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_2',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)

        epsilon = 1e-5  # floor on the covariance diagonal for numerical stability

        mean = layers.conv2D(recomb,
                             'mean',
                             num_filters=n_classes,
                             kernel_size=(1, 1),
                             activation=tf.identity)
        log_cov_diag = layers.conv2D(recomb,
                                     'diag',
                                     num_filters=n_classes,
                                     kernel_size=(1, 1),
                                     activation=tf.identity)
        cov_factor = layers.conv2D(recomb,
                                   'factor',
                                   num_filters=n_classes * rank,
                                   kernel_size=(1, 1),
                                   activation=tf.identity)

        # Flatten H x W x n_classes into one big multivariate normal.
        shape = image_size[:-1] + (n_classes, )
        flat_size = np.prod(shape)
        mean = tf.reshape(mean, [-1, flat_size])
        cov_diag = tf.exp(tf.reshape(log_cov_diag, [-1, flat_size])) + epsilon
        cov_factor = tf.reshape(cov_factor, [-1, flat_size, rank])

        if diagonal:
            dist = DiagonalMultivariateNormal(loc=mean, cov_diag=cov_diag)
        else:
            dist = LowRankMultivariateNormal(loc=mean,
                                             cov_diag=cov_diag,
                                             cov_factor=cov_factor)

        # Reparameterized sample, reshaped back to image layout.
        s = dist.rsample((1, ))
        s = tf.reshape(s, (-1, ) + shape)

    return [[s], dist]
def phiseg(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """ This is a U-NET like arch with skips before and after latent space and a rather simple decoder

    PHiSeg likelihood network: maps the hierarchy of latent samples `z_list`
    (coarsest last) to per-level segmentation logit maps, all resized to the
    full image resolution.

    NOTE(review): a second, different `phiseg` (a posterior network) is
    defined later in this file and shadows this one at import time —
    consider renaming.

    Args:
        z_list: list of latent tensors, one per latent level (finest first).
        training: bool (or bool tensor) for normalisation layers.
        image_size: (H, W, ...) — outputs are resized to image_size[0:2].
        n_classes: number of output classes.
        scope_reuse: if True, reuse variables in the 'likelihood' scope.
        norm: normalisation function for the conv layers.
        **kwargs: `n0`, `resolution_levels`, `latent_levels`.

    Returns:
        s: list of `latent_levels` logit tensors at full image resolution.
    """
    n0 = kwargs.get('n0', 32)
    # Channel schedule per resolution level (capped at 6*n0 from level 3 on).
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    def increase_resolution(x, times, num_filters, name):
        # Upsample by 2 `times` times, one conv after each upsampling step.
        with tf.variable_scope(name):
            nett = x
            for i in range(times):
                nett = layers.bilinear_upsample2D(nett, 'ups_%d' % i, 2)
                nett = layers.conv2D(nett,
                                     'z%d_post' % i,
                                     num_filters=num_filters,
                                     normalisation=norm,
                                     training=training)
            return nett

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        resolution_levels = kwargs.get('resolution_levels', 7)
        latent_levels = kwargs.get('latent_levels', 5)
        # Latents sit resolution_levels - latent_levels scales below the finest.
        lvl_diff = resolution_levels - latent_levels

        post_z = [None] * latent_levels
        post_c = [None] * latent_levels
        s = [None] * latent_levels

        # Generate post_z: two convs per latent, then upsample each latent by
        # lvl_diff so all post_z's sit at matching (per-level) resolutions.
        for i in range(latent_levels):
            net = layers.conv2D(z_list[i],
                                'z%d_post_1' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_post_2' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = increase_resolution(net,
                                      resolution_levels - latent_levels,
                                      num_filters=num_channels[i],
                                      name='preups_%d' % i)
            post_z[i] = net

        # Upstream path: fuse each level with the upsampled level below it.
        post_c[latent_levels - 1] = post_z[latent_levels - 1]

        for i in reversed(range(latent_levels - 1)):
            ups_below = layers.bilinear_upsample2D(post_c[i + 1],
                                                   name='post_z%d_ups' % (i + 1),
                                                   factor=2)
            ups_below = layers.conv2D(ups_below,
                                      'post_z%d_ups_c' % (i + 1),
                                      num_filters=num_channels[i],
                                      normalisation=norm,
                                      training=training)

            concat = tf.concat([post_z[i], ups_below],
                               axis=3,
                               name='concat_%d' % i)

            net = layers.conv2D(concat,
                                'post_c_%d_1' % i,
                                num_filters=num_channels[i + lvl_diff],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'post_c_%d_2' % i,
                                num_filters=num_channels[i + lvl_diff],
                                normalisation=norm,
                                training=training)
            post_c[i] = net

        # Outputs: a 1x1 logit head per level, resized to the input resolution.
        for i in range(latent_levels):
            s_in = layers.conv2D(post_c[i],
                                 'y_lvl%d' % i,
                                 num_filters=n_classes,
                                 kernel_size=(1, 1),
                                 activation=tf.identity)
            s[i] = tf.image.resize_images(
                s_in,
                image_size[0:2],
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return s
def prob_unet2D(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """Probabilistic U-Net likelihood: a U-Net over `x` whose final features
    are concatenated with the (spatially broadcast) global latent `z`.

    Args:
        z_list: list with a single global latent tensor at index 0; its last
            axis is the latent dimension.
        training: bool (or bool tensor) for normalisation layers.
        image_size: (H, W, ...) used to tile the latent over the image grid.
        n_classes: number of output classes.
        scope_reuse: if True, reuse variables in the 'likelihood' scope.
        norm: normalisation function for the conv layers.
        **kwargs: `x` input image (required), `resolution_levels`, `n0`.

    Returns:
        s: single-element list containing the logit tensor.
    """
    x = kwargs.get('x')
    z = z_list[0]

    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)
    # Channel schedule per resolution level (capped at 6*n0 from level 3 on).
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    conv_unit = layers.conv2D
    deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    bs = tf.shape(x)[0]
    zdim = z.get_shape().as_list()[-1]

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Bias is redundant when batch norm follows the convolution.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []
        with tf.variable_scope('encoder'):
            for ii in range(resolution_levels):

                enc.append([])

                # In first layer set input to x rather than max pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))

                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_1' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_2' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_3' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        dec = []
        with tf.variable_scope('decoder'):
            for jj in range(resolution_levels - 1):

                ii = resolution_levels - jj - 1  # used to index the encoder again

                dec.append([])

                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]

                dec[jj].append(deconv_unit(next_inp))

                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]],
                                           axis=3))

                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_1' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_2' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_3' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        # Broadcast the global latent over the full spatial grid and append it
        # to the final U-Net features (standard Prob. U-Net late-fusion).
        z_t = tf.reshape(z, tf.stack((bs, 1, 1, zdim)))
        broadcast_z = tf.tile(z_t, (1, image_size[0], image_size[1], 1))

        net = tf.concat([dec[-1][-1], broadcast_z], axis=-1)

        recomb = conv_unit(net,
                           'recomb_0',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_1',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_2',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)

        # Single-element list for interface compatibility with the multi-level
        # likelihood networks in this file.
        s = [
            layers.conv2D(recomb,
                          'prediction',
                          num_filters=n_classes,
                          kernel_size=(1, 1),
                          activation=tf.identity)
        ]

    return s
def prob_unet2D_arch(
        x, training, nlabels, n0=32, resolution_levels=7, norm=tfnorm.batch_norm,
        conv_unit=layers.conv2D,
        deconv_unit=lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2),
        scope_reuse=False, return_net=False):
    """Deterministic Probabilistic-U-Net backbone (no latent injection).

    Standard encoder/decoder with average-pool downsampling, bilinear
    upsampling, skip connections, three 1x1 recombination convs, and a 1x1
    prediction head.

    Args:
        x: input image tensor (NHWC).
        training: bool (or bool tensor) for normalisation layers.
        nlabels: number of output classes.
        n0: base filter count; channel schedule caps at 6*n0.
        resolution_levels: number of encoder scales.
        norm: normalisation function for the conv layers.
        conv_unit: convolutional building block.
        deconv_unit: upsampling building block.
        scope_reuse: if True, reuse variables in the 'likelihood' scope.
        return_net: NOTE(review): accepted but never used in this function.

    Returns:
        s: logit tensor with `nlabels` channels at input resolution.
    """
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Bias is redundant when batch norm follows the convolution.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []
        with tf.variable_scope('encoder'):
            for ii in range(resolution_levels):

                enc.append([])

                # In first layer set input to x rather than max pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))

                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_1' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_2' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_3' % ii,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        dec = []
        with tf.variable_scope('decoder'):
            for jj in range(resolution_levels - 1):

                ii = resolution_levels - jj - 1  # used to index the encoder again

                dec.append([])

                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]

                dec[jj].append(deconv_unit(next_inp))

                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]],
                                           axis=3))

                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_1' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_2' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_3' % jj,
                              num_filters=num_channels[ii],
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        # 1x1 recombination convs before the prediction head.
        recomb = conv_unit(dec[-1][-1],
                           'recomb_0',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_1',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)
        recomb = conv_unit(recomb,
                           'recomb_2',
                           num_filters=num_channels[0],
                           kernel_size=(1, 1),
                           training=training,
                           normalisation=norm,
                           add_bias=add_bias)

        s = layers.conv2D(recomb,
                          'prediction',
                          num_filters=nlabels,
                          kernel_size=(1, 1),
                          activation=tf.identity)

    return s
def phiseg(x, s_oh, zdim_0, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """PHiSeg posterior network: hierarchical diagonal-Gaussian latents.

    A downsampling path encodes image + segmentation; latents are then
    sampled coarse-to-fine, each finer level conditioned on the upsampled
    coarser samples.

    NOTE(review): this definition shadows the likelihood network also named
    `phiseg` earlier in this file — consider renaming.

    Args:
        x: input image tensor; assumes NHWC layout (spatial dims at [1:3]).
        s_oh: one-hot segmentation, same spatial size as `x`.
        zdim_0: channels of each latent level.
        training: bool (or bool tensor) for normalisation layers.
        scope_reuse: if True, reuse variables in the 'posterior' scope.
        norm: normalisation function for the conv layers.
        **kwargs: `n0`, `latent_levels`, `resolution_levels`, `full_cov_list`
            (NOTE(review): fetched but never used in this function).

    Returns:
        (z, mu, sigma): lists of length `latent_levels`, finest level first.
    """
    n0 = kwargs.get('n0', 32)
    # Channel schedule per resolution level (capped at 6*n0 from level 3 on).
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('posterior') as scope:

        if scope_reuse:
            scope.reuse_variables()

        full_cov_list = kwargs.get('full_cov_list', None)  # NOTE(review): unused here
        n0 = kwargs.get('n0', 32)
        latent_levels = kwargs.get('latent_levels', 5)
        resolution_levels = kwargs.get('resolution_levels', 7)
        spatial_xdim = x.get_shape().as_list()[1:3]  # NOTE(review): unused here

        pre_z = [None] * resolution_levels

        mu = [None] * latent_levels
        sigma = [None] * latent_levels
        z = [None] * latent_levels

        z_ups_mat = []
        for i in range(latent_levels):
            z_ups_mat.append(
                [None] * latent_levels)  # encoding [original resolution][upsampled to]

        # Generate pre_z's
        for i in range(resolution_levels):
            if i == 0:
                # Shift one-hot labels to be zero-centred before concatenation.
                net = tf.concat([x, s_oh - 0.5], axis=-1)
            else:
                net = layers.averagepool2D(pre_z[i - 1])
            net = layers.conv2D(net,
                                'z%d_pre_1' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_2' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_pre_3' % i,
                                num_filters=num_channels[i],
                                normalisation=norm,
                                training=training)
            pre_z[i] = net

        # Generate z's, coarsest latent level first.
        for i in reversed(range(latent_levels)):

            if i == latent_levels - 1:
                # Top level: Gaussian heads directly on the encoder feature.
                mu[i] = layers.conv2D(pre_z[i + resolution_levels - latent_levels],
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)
                sigma[i] = layers.conv2D(pre_z[i + resolution_levels - latent_levels],
                                         'z%d_sigma' % i,
                                         num_filters=zdim_0,
                                         activation=tf.nn.softplus,
                                         kernel_size=(1, 1))
                # Reparameterization trick: z = mu + sigma * eps.
                z[i] = mu[i] + sigma[i] * tf.random_normal(
                    tf.shape(mu[i]), 0, 1, dtype=tf.float32)
            else:
                # Upsample the already-sampled coarser latents towards level i,
                # caching each intermediate resolution in z_ups_mat.
                for j in reversed(range(0, i + 1)):
                    z_below_ups = layers.bilinear_upsample2D(
                        z_ups_mat[j + 1][i + 1], factor=2, name='ups')
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_1' % ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_2' % ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_ups_mat[j][i + 1] = z_below_ups

                # Condition on the encoder feature plus the next-coarser sample.
                z_input = tf.concat([
                    pre_z[i + resolution_levels - latent_levels],
                    z_ups_mat[i][i + 1]
                ],
                                    axis=3,
                                    name='concat_%d' % i)

                z_input = layers.conv2D(z_input,
                                        'z%d_input_1' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm,
                                        training=training)
                z_input = layers.conv2D(z_input,
                                        'z%d_input_2' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm,
                                        training=training)

                mu[i] = layers.conv2D(z_input,
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity,
                                      kernel_size=(1, 1))
                sigma[i] = layers.conv2D(z_input,
                                         'z%d_sigma' % i,
                                         num_filters=zdim_0,
                                         activation=tf.nn.softplus,
                                         kernel_size=(1, 1))
                z[i] = mu[i] + sigma[i] * tf.random_normal(
                    tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            # Seed the upsampling cache with this level's sample.
            z_ups_mat[i][i] = z[i]

    return z, mu, sigma
def segvae_const_latent(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """ This is a U-NET like arch with skips before and after latent space and a rather simple decoder

    Likelihood network paired with the constant-resolution-latent posterior:
    each latent is convolved, upsampled, fused with the level below, and
    decoded into a per-level logit map at full image resolution.

    NOTE(review): this definition shadows the posterior network also named
    `segvae_const_latent` earlier in this file — consider renaming.

    Args:
        z_list: list of latent tensors, one per level.
        training: bool (or bool tensor) for normalisation layers.
        image_size: (H, W, C); outputs are resized to image_size[0:2].
        n_classes: number of output classes.
        scope_reuse: if True, reuse variables in the 'likelihood' scope.
        norm: normalisation function for the conv layers.
        **kwargs: `n0`, `max_channel_power`, `resolution_levels`.

    Returns:
        s: list of `resolution_levels` logit tensors at full image resolution.
    """
    n0 = kwargs.get('n0', 32)
    max_channel_power = kwargs.get('max_channel_power', 4)
    max_channels = n0 * 2**max_channel_power  # NOTE(review): unused here

    def increase_resolution(x, times, name):
        # Upsample by 2 `times` times, keeping the channel count unchanged.
        with tf.variable_scope(name):
            nett = x
            for i in range(times):
                nett = layers.bilinear_upsample2D(nett, 'ups_%d' % i, 2)
                nC = nett.get_shape().as_list()[3]
                nett = layers.conv2D(nett,
                                     'z%d_post' % i,
                                     num_filters=nC,
                                     normalisation=norm,
                                     training=training)
            return nett

    n_channels = image_size[2]  # NOTE(review): unused here
    resolution_levels = kwargs.get('resolution_levels', 3)
    n0 = kwargs.get('n0', 32)

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        z_list_c = []
        pre_out = [None] * resolution_levels
        s = [None] * resolution_levels

        # One conv per latent before any upsampling.
        for i in range(resolution_levels):
            z_list_c.append(
                layers.conv2D(z_list[i],
                              'z%d_post' % i,
                              num_filters=n0 * (i // 2 + 1),
                              normalisation=norm,
                              training=training))

        pre_out[resolution_levels - 1] = z_list_c[resolution_levels - 1]

        # Fuse each level with the upsampled level below it.
        # NOTE(review): `top` is upsampled (levels - i - 1) times while
        # `bottom` only once — this only concatenates cleanly if the latents
        # share one resolution (as the const-latent posterior produces);
        # confirm against the paired posterior.
        for i in reversed(range(resolution_levels - 1)):
            top = increase_resolution(z_list_c[i], resolution_levels - i - 1,
                                      'upsample_top_%d' % i)
            bottom = increase_resolution(pre_out[i + 1], 1,
                                         'upsample_bottom_%d' % i)
            net = tf.concat([top, bottom], axis=3)
            pre_out[i] = layers.conv2D(net,
                                       'preout_%d' % i,
                                       num_filters=n0 * (i // 2 + 1),
                                       normalisation=norm,
                                       training=training)

        # Per-level 1x1 logit heads, resized to the input resolution.
        for i in range(resolution_levels):
            s_in = layers.conv2D(pre_out[i],
                                 'y_lvl%d' % i,
                                 num_filters=n_classes,
                                 kernel_size=(1, 1),
                                 activation=tf.identity)
            s[i] = tf.image.resize_images(
                s_in,
                image_size[0:2],
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return s
def unet_T_L_noconcat(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """ This is a U-NET like arch with skips before and after latent space and a rather simple decoder

    Ablation of the T-L likelihood WITHOUT cross-level concatenation: each
    level is decoded from its own latent only (no fusion with the level
    below), then mapped to a logit map at full image resolution.

    Args:
        z_list: list of latent tensors, one per latent level (finest first).
        training: bool (or bool tensor) for normalisation layers.
        image_size: (H, W, ...); outputs are resized to image_size[0:2].
        n_classes: number of output classes.
        scope_reuse: if True, reuse variables in the 'likelihood' scope.
        norm: normalisation function for the conv layers.
        **kwargs: `n0`, `max_channel_power`, `resolution_levels`,
            `latent_levels`.

    Returns:
        s: list of `latent_levels` logit tensors at full image resolution.
    """
    n0 = kwargs.get('n0', 32)
    max_channel_power = kwargs.get('max_channel_power', 4)
    max_channels = n0 * 2**max_channel_power

    def increase_resolution(x, times, name):
        # Upsample by 2 `times` times, doubling channels (capped at max_channels).
        with tf.variable_scope(name):
            nett = x
            for i in range(times):
                nett = layers.bilinear_upsample2D(nett, 'ups_%d' % i, 2)
                nC = nett.get_shape().as_list()[3]
                nett = layers.conv2D(nett,
                                     'z%d_post' % i,
                                     num_filters=min(nC * 2, max_channels),
                                     normalisation=norm,
                                     training=training)
            return nett

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        resolution_levels = kwargs.get('resolution_levels', 6)
        latent_levels = kwargs.get('latent_levels', 3)
        n0 = kwargs.get('n0', 32)

        post_z = [None] * latent_levels
        post_c = [None] * latent_levels
        s = [None] * latent_levels

        # Generate post_z: two convs per latent, then upsample each latent to
        # its target decoder resolution.
        for i in range(latent_levels):
            net = layers.conv2D(z_list[i],
                                'z%d_post_1' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'z%d_post_2' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)
            net = increase_resolution(net,
                                      resolution_levels - latent_levels,
                                      name='preups_%d' % i)
            post_z[i] = net

        # Upstream path — here each level only sees its own post_z
        # (the "noconcat" ablation: no fusion with the level below).
        post_c[latent_levels - 1] = post_z[latent_levels - 1]

        for i in reversed(range(latent_levels - 1)):
            concat = post_z[i]
            net = layers.conv2D(concat,
                                'post_c_%d_1' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)
            net = layers.conv2D(net,
                                'post_c_%d_2' % i,
                                num_filters=n0 * (i // 2 + 1),
                                normalisation=norm,
                                training=training)
            post_c[i] = net

        # Outputs: per-level 1x1 logit heads resized to the input resolution.
        for i in range(latent_levels):
            s_in = layers.conv2D(post_c[i],
                                 'y_lvl%d' % i,
                                 num_filters=n_classes,
                                 kernel_size=(1, 1),
                                 activation=tf.identity)
            s[i] = tf.image.resize_images(
                s_in,
                image_size[0:2],
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return s
def unet2D(x, training, nlabels, n0=64, resolution_levels=5, norm=tfnorm.batch_norm, conv_unit=layers.conv2D, deconv_unit=layers.transposed_conv2D, simplified_dec=False, scope_reuse=False, return_net=False):
    """Standard 2D U-Net segmenter.

    Encoder: max-pool downsampling with two convs per scale, channels
    doubling each level. Decoder: transposed-conv (or provided) upsampling,
    skip concatenation with the matching encoder scale, two convs, then a
    1x1 prediction head.

    Args:
        x: input image tensor (NHWC).
        training: bool (or bool tensor) for normalisation layers.
        nlabels: number of output classes.
        n0: base filter count (doubled at each deeper level).
        resolution_levels: number of encoder scales.
        norm: normalisation function for the conv layers.
        conv_unit: convolutional building block (e.g. residual units).
        deconv_unit: upsampling building block.
        simplified_dec: if True, upconvs output `nlabels` channels instead of
            the full channel schedule (lighter decoder).
        scope_reuse: if True, reuse variables in the 'segmenter' scope.
        return_net: if True, also return the list of all intermediate tensors.

    Returns:
        `output` logits, or `(output, net)` when `return_net` is True.
    """
    with tf.variable_scope('segmenter') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Bias is redundant when batch norm follows the convolution.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []
        with tf.variable_scope('encoder'):
            for ii in range(resolution_levels):

                enc.append([])

                # In first layer set input to x rather than max pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.maxpool2D(enc[ii - 1][-1]))

                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_1' % ii,
                              num_filters=n0 * (2**ii),
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1],
                              'conv_%d_2' % ii,
                              num_filters=n0 * (2**ii),
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        dec = []
        with tf.variable_scope('decoder'):
            for jj in range(resolution_levels - 1):

                ii = resolution_levels - jj - 1  # used to index the encoder again

                dec.append([])

                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]

                if simplified_dec:
                    num_transconv_filters = nlabels
                else:
                    num_transconv_filters = n0 * (2**(ii - 1))

                dec[jj].append(
                    deconv_unit(next_inp,
                                name='upconv_%d' % jj,
                                num_filters=num_transconv_filters,
                                training=training,
                                normalisation=norm,
                                add_bias=add_bias))

                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]],
                                           axis=3))

                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_1' % jj,
                              num_filters=n0 * (2**(ii - 1)),
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias,
                              projection=True)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1],
                              'conv_%d_2' % jj,
                              num_filters=n0 * (2**(ii - 1)),
                              training=training,
                              normalisation=norm,
                              add_bias=add_bias))

        output = layers.conv2D(dec[-1][-1],
                               'prediction',
                               num_filters=nlabels,
                               kernel_size=(1, 1),
                               activation=tf.identity,
                               training=training,
                               normalisation=norm,
                               add_bias=add_bias)

        dec[-1].append(output)

    if return_net:
        net = enc + dec
        return output, net

    return output