def increase_resolution(x, times, num_filters, name):
    """Repeatedly 2x-upsample `x`, applying one conv after every upsampling.

    NOTE(review): `norm` and `training` are free names resolved from the
    enclosing scope (this looks like a module-level copy of the nested helper
    inside `phiseg`) — confirm they are defined where this function lives.

    Args:
        x: input feature map (NHWC tensor).
        times: number of upsample+conv stages to apply.
        num_filters: channel count produced by each conv.
        name: variable scope under which all ops are created.

    Returns:
        The feature map upsampled by a factor of 2**times.
    """
    with tf.variable_scope(name):
        out = x
        for step in range(times):
            out = layers.bilinear_upsample2D(out, 'ups_%d' % step, 2)
            out = layers.conv2D(out,
                                'z%d_post' % step,
                                num_filters=num_filters,
                                normalisation=norm,
                                training=training)
        return out
def resize_features(features, size, name):
    """Bilinearly resize every feature map in `features` to a common size and
    concatenate the results along the channel axis.

    Args:
        features: sequence of NHWC feature tensors.
        size: target spatial size, forwarded to `layers.bilinear_upsample2D`.
        name: base op name; the feature index is appended per tensor.

    Returns:
        One tensor with all resized features concatenated on axis -1.

    Raises:
        UnboundLocalError: if `features` is empty (nothing to concatenate);
        callers are expected to pass at least one feature map.
    """
    for f, this_feature in enumerate(features):
        this_feature_resized = layers.bilinear_upsample2D(
            this_feature, size, name + str(f))
        # Fixed: the original used `f is 0` — identity comparison with an int
        # literal is implementation-dependent and a SyntaxWarning on CPython 3.8+.
        if f == 0:
            features_resized = this_feature_resized
        else:
            features_resized = tf.concat(
                (features_resized, this_feature_resized), axis=-1)
    return features_resized
def increase_resolution(x, times, name):
    """Repeatedly 2x-upsample `x`; each stage's conv doubles the channel
    count, capped at the module-level `max_channels`.

    NOTE(review): depends on free names `max_channels`, `norm` and `training`
    from the enclosing scope, and shadows the earlier `increase_resolution`
    definition if both live in the same module — confirm this is intended.

    Args:
        x: input feature map (NHWC tensor).
        times: number of upsample+conv stages.
        name: variable scope for the created ops.

    Returns:
        The feature map upsampled by a factor of 2**times.
    """
    with tf.variable_scope(name):
        out = x
        for step in range(times):
            out = layers.bilinear_upsample2D(out, 'ups_%d' % step, 2)
            # Read the current channel count so each stage can double it.
            channels_in = out.get_shape().as_list()[3]
            out = layers.conv2D(out,
                                'z%d_post' % step,
                                num_filters=min(channels_in * 2, max_channels),
                                normalisation=norm,
                                training=training)
        return out
def hybrid(x, s_oh, zdim_0, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """Hierarchical posterior network q(z | x, s) with optional full-covariance levels.

    Encodes the image `x` concatenated with the centred one-hot segmentation
    `s_oh` through `resolution_levels` conv blocks, then samples latents at
    `latent_levels` scales from coarse to fine.  Each level draws from either
    a diagonal Gaussian or a full-covariance Gaussian (per-channel Cholesky
    factor over the flattened spatial grid), chosen per level by
    `full_cov_list` — which is indexed unconditionally, so it must be provided
    with one bool per latent level.

    Args:
        x: input image batch, NHWC.
          # assumes spatial dims divisible by 2**(resolution_levels-1) — TODO confirm
        s_oh: one-hot segmentation with the same spatial size as `x`.
        zdim_0: number of latent channels at every level.
        training: training flag forwarded to the normalisation layers.
        scope_reuse: if True, reuse variables under the 'posterior' scope.
        norm: normalisation function used inside conv layers.
        **kwargs: n0 (base filter count, 32), latent_levels (5),
            resolution_levels (7), full_cov_list (required, see above),
            full_latent_dependencies (False — if True, each level conditions
            on ALL coarser latents instead of just the next one).

    Returns:
        (z, mu, sigma): lists of length `latent_levels`.  For full-covariance
        levels `sigma[i]` is a lower-triangular Cholesky factor of shape
        [batch, zdim_0, HW, HW]; otherwise a per-pixel stddev map.
    """
    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('posterior') as scope:
        if scope_reuse:
            scope.reuse_variables()

        full_cov_list = kwargs.get('full_cov_list', None)
        n0 = kwargs.get('n0', 32)
        latent_levels = kwargs.get('latent_levels', 5)
        resolution_levels = kwargs.get('resolution_levels', 7)
        spatial_xdim = x.get_shape().as_list()[1:3]
        full_latent_dependencies = kwargs.get('full_latent_dependencies', False)

        pre_z = [None] * resolution_levels
        mu = [None] * latent_levels
        sigma = [None] * latent_levels
        z = [None] * latent_levels

        z_ups_mat = []
        for i in range(latent_levels):
            z_ups_mat.append(
                [None] * latent_levels)  # encoding [original resolution][upsampled to]

        # Generate pre_z's: one conv block per resolution level, each fed by
        # average-pooling the previous level's output.
        for i in range(resolution_levels):
            if i == 0:
                # Centre the one-hot labels around zero before concatenating.
                net = tf.concat([x, s_oh - 0.5], axis=-1)
            else:
                net = layers.averagepool2D(pre_z[i - 1])
            net = layers.conv2D(net, 'z%d_pre_1' % i, num_filters=num_channels[i],
                                normalisation=norm, training=training)
            net = layers.conv2D(net, 'z%d_pre_2' % i, num_filters=num_channels[i],
                                normalisation=norm, training=training)
            net = layers.conv2D(net, 'z%d_pre_3' % i, num_filters=num_channels[i],
                                normalisation=norm, training=training)
            pre_z[i] = net

        # Generate z's, starting at the coarsest latent level.
        for i in reversed(range(latent_levels)):
            # Spatial size of this latent level's grid.
            spatial_zdim = [
                d // 2**(i + resolution_levels - latent_levels)
                for d in spatial_xdim
            ]
            spatial_cov_dim = spatial_zdim[0] * spatial_zdim[1]

            if i == latent_levels - 1:
                # Coarsest level: conditioned on the encoder features only.
                mu[i] = layers.conv2D(pre_z[i + resolution_levels - latent_levels],
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)

                if full_cov_list[i] == True:
                    # Predict a flattened lower-triangular matrix per channel.
                    l = layers.dense_layer(
                        pre_z[i + resolution_levels - latent_levels],
                        'z%d_sigma' % i,
                        hidden_units=zdim_0 * spatial_cov_dim * (spatial_cov_dim + 1) // 2,
                        activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim * (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp, tf.nn.softplus(tf.linalg.diag_part(Lp))
                    )  # Cholesky factors must have positive diagonal
                    sigma[i] = L

                    # Reparameterised sample: z = mu + L @ eps, with eps
                    # reshaped to [batch, zdim_0, HW, 1] for the matmul.
                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))
                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp, [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])
                    z[i] = mu[i] + eps_tmp
                else:
                    # Diagonal Gaussian: per-pixel stddev via softplus.
                    sigma[i] = layers.conv2D(pre_z[i + resolution_levels - latent_levels],
                                             'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus,
                                             kernel_size=(1, 1))
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)
            else:
                # Upsample every coarser latent down to this level's grid,
                # filling column i+1 of z_ups_mat on the way.
                for j in reversed(range(0, i + 1)):
                    z_below_ups = layers.bilinear_upsample2D(
                        z_ups_mat[j + 1][i + 1], factor=2, name='ups')
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_1' % ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_2' % ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_ups_mat[j][i + 1] = z_below_ups

                if full_latent_dependencies:
                    # Condition on ALL coarser latents (upsampled to this grid).
                    z_input = tf.concat(
                        [pre_z[i + resolution_levels - latent_levels]] +
                        z_ups_mat[i][(i + 1):latent_levels],
                        axis=3,
                        name='concat_%d' % i)
                else:
                    # Condition only on the next-coarser latent.
                    z_input = tf.concat([
                        pre_z[i + resolution_levels - latent_levels],
                        z_ups_mat[i][i + 1]
                    ],
                                        axis=3,
                                        name='concat_%d' % i)

                z_input = layers.conv2D(z_input, 'z%d_input_1' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm, training=training)
                z_input = layers.conv2D(z_input, 'z%d_input_2' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm, training=training)

                mu[i] = layers.conv2D(z_input, 'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity,
                                      kernel_size=(1, 1))

                if full_cov_list[i] == True:
                    # Same full-covariance construction as the coarsest level,
                    # but conditioned on z_input.
                    l = layers.dense_layer(z_input, 'z%d_sigma' % i,
                                           hidden_units=zdim_0 * spatial_cov_dim *
                                           (spatial_cov_dim + 1) // 2,
                                           activation=tf.identity)
                    l = tf.reshape(l, [
                        -1, zdim_0, spatial_cov_dim * (spatial_cov_dim + 1) // 2
                    ])
                    Lp = tf.contrib.distributions.fill_triangular(l)
                    L = tf.linalg.set_diag(
                        Lp, tf.nn.softplus(tf.linalg.diag_part(Lp)))
                    sigma[i] = L

                    eps = tf.random_normal(tf.shape(mu[i]))
                    eps = tf.transpose(eps, perm=[0, 3, 1, 2])
                    bs = tf.shape(x)[0]
                    eps = tf.reshape(eps, tf.stack([bs, zdim_0, -1, 1]))
                    eps_tmp = tf.matmul(sigma[i], eps)
                    eps_tmp = tf.transpose(eps_tmp, perm=[0, 2, 3, 1])
                    eps_tmp = tf.reshape(
                        eps_tmp, [bs, spatial_zdim[0], spatial_zdim[1], zdim_0])
                    z[i] = mu[i] + eps_tmp
                else:
                    sigma[i] = layers.conv2D(z_input, 'z%d_sigma' % i,
                                             num_filters=zdim_0,
                                             activation=tf.nn.softplus,
                                             kernel_size=(1, 1))
                    z[i] = mu[i] + sigma[i] * tf.random_normal(
                        tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            # Seed the diagonal of z_ups_mat so the next (finer) iteration can
            # read z_ups_mat[j + 1][i + 1] above.
            z_ups_mat[i][i] = z[i]

    return z, mu, sigma
def proposed(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, rank=10, diagonal=False, **kwargs):
    """U-Net likelihood with a (low-rank or diagonal) multivariate Gaussian over logits.

    Runs a standard encoder/decoder U-Net over `x` (taken from kwargs), then
    three 1x1 'recomb' convs, and finally predicts the parameters of a
    Gaussian distribution over the flattened per-pixel class logits.

    NOTE(review): the `z_list` parameter is never used in this body — the
    input comes from kwargs['x']; presumably kept for signature compatibility
    with the other likelihood networks — confirm.

    Args:
        z_list: unused (see note above).
        training: training flag for normalisation layers.
        image_size: (H, W, C) tuple; H, W, and n_classes define the flattened
            logit dimension.
        n_classes: number of output classes.
        scope_reuse: reuse variables under 'likelihood'.
        norm: normalisation function; batch norm disables conv biases.
        rank: rank of the low-rank covariance factor.
        diagonal: if True, use a diagonal covariance only.
        **kwargs: x (input image, required), resolution_levels (7), n0 (32).

    Returns:
        [[s], dist]: a single sample reshaped to image layout, plus the
        distribution object itself.
    """
    x = kwargs.get('x')
    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    conv_unit = layers.conv2D
    deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    with tf.variable_scope('likelihood') as scope:
        if scope_reuse:
            scope.reuse_variables()

        # A conv bias is redundant when batch norm follows (its shift absorbs it).
        add_bias = False if norm == tfnorm.batch_norm else True

        # Encoder: enc[ii] collects the intermediate tensors of level ii.
        enc = []
        with tf.variable_scope('encoder'):
            for ii in range(resolution_levels):
                enc.append([])
                # In first layer set input to x rather than max pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))
                enc[ii].append(
                    conv_unit(enc[ii][-1], 'conv_%d_1' % ii,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1], 'conv_%d_2' % ii,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1], 'conv_%d_3' % ii,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))

        # Decoder: upsample, concat the matching encoder level, three convs.
        dec = []
        with tf.variable_scope('decoder'):
            for jj in range(resolution_levels - 1):
                ii = resolution_levels - jj - 1  # used to index the encoder again
                dec.append([])
                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]
                dec[jj].append(deconv_unit(next_inp))
                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]], axis=3))
                dec[jj].append(
                    conv_unit(dec[jj][-1], 'conv_%d_1' % jj,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1], 'conv_%d_2' % jj,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))
                dec[jj].append(
                    conv_unit(dec[jj][-1], 'conv_%d_3' % jj,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))

        net = dec[-1][-1]

        recomb = conv_unit(net, 'recomb_0', num_filters=num_channels[0],
                           kernel_size=(1, 1), training=training,
                           normalisation=norm, add_bias=add_bias)
        recomb = conv_unit(recomb, 'recomb_1', num_filters=num_channels[0],
                           kernel_size=(1, 1), training=training,
                           normalisation=norm, add_bias=add_bias)
        recomb = conv_unit(recomb, 'recomb_2', num_filters=num_channels[0],
                           kernel_size=(1, 1), training=training,
                           normalisation=norm, add_bias=add_bias)

        # Keeps the covariance diagonal strictly positive.
        epsilon = 1e-5
        mean = layers.conv2D(recomb, 'mean', num_filters=n_classes,
                             kernel_size=(1, 1), activation=tf.identity)
        log_cov_diag = layers.conv2D(recomb, 'diag', num_filters=n_classes,
                                     kernel_size=(1, 1), activation=tf.identity)
        cov_factor = layers.conv2D(recomb, 'factor', num_filters=n_classes * rank,
                                   kernel_size=(1, 1), activation=tf.identity)

        # Flatten per-pixel logits to one long vector per batch element.
        shape = image_size[:-1] + (n_classes, )
        flat_size = np.prod(shape)
        mean = tf.reshape(mean, [-1, flat_size])
        cov_diag = tf.exp(tf.reshape(log_cov_diag, [-1, flat_size])) + epsilon
        cov_factor = tf.reshape(cov_factor, [-1, flat_size, rank])

        if diagonal:
            dist = DiagonalMultivariateNormal(loc=mean, cov_diag=cov_diag)
        else:
            dist = LowRankMultivariateNormal(loc=mean, cov_diag=cov_diag,
                                             cov_factor=cov_factor)

        # One reparameterised sample, reshaped back to image layout.
        s = dist.rsample((1, ))
        s = tf.reshape(s, (-1, ) + shape)

    return [[s], dist]
def phiseg(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """PHiSeg-style likelihood: decode hierarchical latents into per-level segmentations.

    A U-Net-like decoder with skips before and after the latent space.  Each
    latent in `z_list` is first processed and upsampled to a common scale
    (post_z), then combined coarse-to-fine (post_c), and finally mapped to
    `n_classes` logits that are resized to the full image size.

    Args:
        z_list: list of latent tensors, one per latent level (coarsest last).
        training: training flag for normalisation layers.
        image_size: output (H, W, ...) — only [0:2] is used for resizing.
        n_classes: number of segmentation classes.
        scope_reuse: reuse variables under 'likelihood'.
        norm: normalisation function for conv layers.
        **kwargs: n0 (32), resolution_levels (7), latent_levels (5).

    Returns:
        s: list of `latent_levels` logit tensors, each resized to image size.
    """
    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    def increase_resolution(x, times, num_filters, name):
        # Upsample `x` by 2x `times` times, one conv after each upsampling.
        with tf.variable_scope(name):
            nett = x
            for i in range(times):
                nett = layers.bilinear_upsample2D(nett, 'ups_%d' % i, 2)
                nett = layers.conv2D(nett, 'z%d_post' % i,
                                     num_filters=num_filters,
                                     normalisation=norm, training=training)
        return nett

    with tf.variable_scope('likelihood') as scope:
        if scope_reuse:
            scope.reuse_variables()

        resolution_levels = kwargs.get('resolution_levels', 7)
        latent_levels = kwargs.get('latent_levels', 5)
        # Offset between resolution and latent indexing.
        lvl_diff = resolution_levels - latent_levels

        post_z = [None] * latent_levels
        post_c = [None] * latent_levels
        s = [None] * latent_levels

        # Generate post_z: process each latent and bring it up to the common scale.
        for i in range(latent_levels):
            net = layers.conv2D(z_list[i], 'z%d_post_1' % i,
                                num_filters=num_channels[i],
                                normalisation=norm, training=training)
            net = layers.conv2D(net, 'z%d_post_2' % i,
                                num_filters=num_channels[i],
                                normalisation=norm, training=training)
            net = increase_resolution(net, resolution_levels - latent_levels,
                                      num_filters=num_channels[i],
                                      name='preups_%d' % i)
            post_z[i] = net

        # Upstream path: combine levels coarse-to-fine.
        post_c[latent_levels - 1] = post_z[latent_levels - 1]
        for i in reversed(range(latent_levels - 1)):
            ups_below = layers.bilinear_upsample2D(post_c[i + 1],
                                                   name='post_z%d_ups' % (i + 1),
                                                   factor=2)
            ups_below = layers.conv2D(ups_below, 'post_z%d_ups_c' % (i + 1),
                                      num_filters=num_channels[i],
                                      normalisation=norm, training=training)
            concat = tf.concat([post_z[i], ups_below], axis=3, name='concat_%d' % i)
            net = layers.conv2D(concat, 'post_c_%d_1' % i,
                                num_filters=num_channels[i + lvl_diff],
                                normalisation=norm, training=training)
            net = layers.conv2D(net, 'post_c_%d_2' % i,
                                num_filters=num_channels[i + lvl_diff],
                                normalisation=norm, training=training)
            post_c[i] = net

        # Outputs: 1x1 conv to logits, nearest-neighbour resize to image size.
        for i in range(latent_levels):
            s_in = layers.conv2D(post_c[i], 'y_lvl%d' % i,
                                 num_filters=n_classes, kernel_size=(1, 1),
                                 activation=tf.identity)
            s[i] = tf.image.resize_images(
                s_in, image_size[0:2],
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return s
def prob_unet2D(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """Probabilistic U-Net likelihood: U-Net features + a broadcast global latent.

    Runs a standard U-Net over `x` (from kwargs), tiles the single global
    latent `z_list[0]` across the spatial grid, concatenates it with the
    final decoder features, and predicts logits via 1x1 convs.

    Args:
        z_list: list of latents; only z_list[0] (a global latent vector,
            channels-last) is used.
        training: training flag for normalisation layers.
        image_size: (H, W, ...) — [0] and [1] give the tile size for z.
        n_classes: number of output classes.
        scope_reuse: reuse variables under 'likelihood'.
        norm: normalisation function; batch norm disables conv biases.
        **kwargs: x (input image, required), resolution_levels (7), n0 (32).

    Returns:
        A one-element list with the logit tensor (kept as a list for API
        compatibility with the multi-level likelihoods).
    """
    x = kwargs.get('x')
    z = z_list[0]

    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    conv_unit = layers.conv2D
    deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    bs = tf.shape(x)[0]
    zdim = z.get_shape().as_list()[-1]

    with tf.variable_scope('likelihood') as scope:
        if scope_reuse:
            scope.reuse_variables()

        # A conv bias is redundant when batch norm follows (its shift absorbs it).
        add_bias = False if norm == tfnorm.batch_norm else True

        # Encoder: enc[ii] collects the intermediate tensors of level ii.
        enc = []
        with tf.variable_scope('encoder'):
            for ii in range(resolution_levels):
                enc.append([])
                # In first layer set input to x rather than max pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))
                enc[ii].append(
                    conv_unit(enc[ii][-1], 'conv_%d_1' % ii,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1], 'conv_%d_2' % ii,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1], 'conv_%d_3' % ii,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))

        # Decoder: upsample, concat matching encoder level, three convs.
        dec = []
        with tf.variable_scope('decoder'):
            for jj in range(resolution_levels - 1):
                ii = resolution_levels - jj - 1  # used to index the encoder again
                dec.append([])
                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]
                dec[jj].append(deconv_unit(next_inp))
                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]], axis=3))
                dec[jj].append(
                    conv_unit(dec[jj][-1], 'conv_%d_1' % jj,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1], 'conv_%d_2' % jj,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))
                dec[jj].append(
                    conv_unit(dec[jj][-1], 'conv_%d_3' % jj,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))

        # Broadcast the global latent over the spatial grid and fuse it with
        # the finest decoder features.
        z_t = tf.reshape(z, tf.stack((bs, 1, 1, zdim)))
        broadcast_z = tf.tile(z_t, (1, image_size[0], image_size[1], 1))
        net = tf.concat([dec[-1][-1], broadcast_z], axis=-1)

        recomb = conv_unit(net, 'recomb_0', num_filters=num_channels[0],
                           kernel_size=(1, 1), training=training,
                           normalisation=norm, add_bias=add_bias)
        recomb = conv_unit(recomb, 'recomb_1', num_filters=num_channels[0],
                           kernel_size=(1, 1), training=training,
                           normalisation=norm, add_bias=add_bias)
        recomb = conv_unit(recomb, 'recomb_2', num_filters=num_channels[0],
                           kernel_size=(1, 1), training=training,
                           normalisation=norm, add_bias=add_bias)

        s = [
            layers.conv2D(recomb, 'prediction', num_filters=n_classes,
                          kernel_size=(1, 1), activation=tf.identity)
        ]

    return s
def prob_unet2D_arch(
        x, training, nlabels, n0=32, resolution_levels=7, norm=tfnorm.batch_norm,
        conv_unit=layers.conv2D,
        deconv_unit=lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2),
        scope_reuse=False, return_net=False):
    """Deterministic U-Net backbone (no latent): encoder, decoder, 1x1 logit head.

    Args:
        x: input image batch, NHWC.
        training: training flag for normalisation layers.
        nlabels: number of output classes.
        n0: base filter count; channels grow as [n0, 2n0, 4n0, 6n0, ...].
        resolution_levels: number of encoder scales.
        norm: normalisation function; batch norm disables conv biases.
        conv_unit: conv layer constructor (injectable for residual units etc.).
        deconv_unit: upsampling constructor.
        scope_reuse: reuse variables under 'likelihood'.
        return_net: NOTE(review) — currently unused; the function always
            returns only the logits.  Confirm whether intermediate features
            were meant to be returned when this is True.

    Returns:
        Logit tensor with `nlabels` channels at the input resolution.
    """
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('likelihood') as scope:
        if scope_reuse:
            scope.reuse_variables()

        # A conv bias is redundant when batch norm follows (its shift absorbs it).
        add_bias = False if norm == tfnorm.batch_norm else True

        # Encoder: enc[ii] collects the intermediate tensors of level ii.
        enc = []
        with tf.variable_scope('encoder'):
            for ii in range(resolution_levels):
                enc.append([])
                # In first layer set input to x rather than max pooling
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))
                enc[ii].append(
                    conv_unit(enc[ii][-1], 'conv_%d_1' % ii,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1], 'conv_%d_2' % ii,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))
                enc[ii].append(
                    conv_unit(enc[ii][-1], 'conv_%d_3' % ii,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))

        # Decoder: upsample, concat matching encoder level, three convs.
        dec = []
        with tf.variable_scope('decoder'):
            for jj in range(resolution_levels - 1):
                ii = resolution_levels - jj - 1  # used to index the encoder again
                dec.append([])
                if jj == 0:
                    next_inp = enc[ii][-1]
                else:
                    next_inp = dec[jj - 1][-1]
                dec[jj].append(deconv_unit(next_inp))
                # skip connection
                dec[jj].append(
                    layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]], axis=3))
                dec[jj].append(
                    conv_unit(dec[jj][-1], 'conv_%d_1' % jj,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias)
                )  # projection True to make it work with res units.
                dec[jj].append(
                    conv_unit(dec[jj][-1], 'conv_%d_2' % jj,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))
                dec[jj].append(
                    conv_unit(dec[jj][-1], 'conv_%d_3' % jj,
                              num_filters=num_channels[ii], training=training,
                              normalisation=norm, add_bias=add_bias))

        recomb = conv_unit(dec[-1][-1], 'recomb_0', num_filters=num_channels[0],
                           kernel_size=(1, 1), training=training,
                           normalisation=norm, add_bias=add_bias)
        recomb = conv_unit(recomb, 'recomb_1', num_filters=num_channels[0],
                           kernel_size=(1, 1), training=training,
                           normalisation=norm, add_bias=add_bias)
        recomb = conv_unit(recomb, 'recomb_2', num_filters=num_channels[0],
                           kernel_size=(1, 1), training=training,
                           normalisation=norm, add_bias=add_bias)

        s = layers.conv2D(recomb, 'prediction', num_filters=nlabels,
                          kernel_size=(1, 1), activation=tf.identity)

    return s
def phiseg(x, s_oh, zdim_0, training, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """PHiSeg posterior network q(z | x, s) with diagonal Gaussians at every level.

    Encodes the image `x` concatenated with the centred one-hot segmentation
    `s_oh` through `resolution_levels` conv blocks, then samples a diagonal
    Gaussian latent at each of `latent_levels` scales, coarse to fine; finer
    levels are conditioned on the upsampled next-coarser latent.

    NOTE(review): if this lives in the same module as the earlier `phiseg`
    likelihood function, this definition shadows it — confirm they come from
    separate modules.

    Cleanup vs. the previous revision (behavior unchanged): removed dead
    locals — `full_cov_list` (never read in this diagonal-only variant), a
    duplicate `n0 = kwargs.get('n0', 32)`, and the unused
    `spatial_xdim`/`spatial_zdim`/`spatial_cov_dim` computations.

    Args:
        x: input image batch, NHWC.
        s_oh: one-hot segmentation with the same spatial size as `x`.
        zdim_0: number of latent channels at every level.
        training: training flag for normalisation layers.
        scope_reuse: reuse variables under the 'posterior' scope.
        norm: normalisation function for conv layers.
        **kwargs: n0 (base filter count, 32), latent_levels (5),
            resolution_levels (7).

    Returns:
        (z, mu, sigma): lists of length `latent_levels`; `sigma[i]` is a
        per-pixel stddev map (softplus output).
    """
    n0 = kwargs.get('n0', 32)
    num_channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('posterior') as scope:
        if scope_reuse:
            scope.reuse_variables()

        latent_levels = kwargs.get('latent_levels', 5)
        resolution_levels = kwargs.get('resolution_levels', 7)

        pre_z = [None] * resolution_levels
        mu = [None] * latent_levels
        sigma = [None] * latent_levels
        z = [None] * latent_levels

        z_ups_mat = []
        for i in range(latent_levels):
            z_ups_mat.append(
                [None] * latent_levels)  # encoding [original resolution][upsampled to]

        # Generate pre_z's: one conv block per resolution level.
        for i in range(resolution_levels):
            if i == 0:
                # Centre the one-hot labels around zero before concatenating.
                net = tf.concat([x, s_oh - 0.5], axis=-1)
            else:
                net = layers.averagepool2D(pre_z[i - 1])
            net = layers.conv2D(net, 'z%d_pre_1' % i, num_filters=num_channels[i],
                                normalisation=norm, training=training)
            net = layers.conv2D(net, 'z%d_pre_2' % i, num_filters=num_channels[i],
                                normalisation=norm, training=training)
            net = layers.conv2D(net, 'z%d_pre_3' % i, num_filters=num_channels[i],
                                normalisation=norm, training=training)
            pre_z[i] = net

        # Generate z's, starting at the coarsest latent level.
        for i in reversed(range(latent_levels)):
            if i == latent_levels - 1:
                # Coarsest level: conditioned on the encoder features only.
                mu[i] = layers.conv2D(pre_z[i + resolution_levels - latent_levels],
                                      'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity)
                sigma[i] = layers.conv2D(pre_z[i + resolution_levels - latent_levels],
                                         'z%d_sigma' % i,
                                         num_filters=zdim_0,
                                         activation=tf.nn.softplus,
                                         kernel_size=(1, 1))
                z[i] = mu[i] + sigma[i] * tf.random_normal(
                    tf.shape(mu[i]), 0, 1, dtype=tf.float32)
            else:
                # Upsample coarser latents to this level's grid, filling
                # column i+1 of z_ups_mat on the way.
                for j in reversed(range(0, i + 1)):
                    z_below_ups = layers.bilinear_upsample2D(
                        z_ups_mat[j + 1][i + 1], factor=2, name='ups')
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_1' % ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_below_ups = layers.conv2D(z_below_ups,
                                                name='z%d_ups_to_%d_c_2' % ((i + 1), (j + 1)),
                                                num_filters=zdim_0 * n0,
                                                normalisation=norm,
                                                training=training)
                    z_ups_mat[j][i + 1] = z_below_ups

                # Condition on the encoder features and the next-coarser latent.
                z_input = tf.concat([
                    pre_z[i + resolution_levels - latent_levels],
                    z_ups_mat[i][i + 1]
                ],
                                    axis=3,
                                    name='concat_%d' % i)
                z_input = layers.conv2D(z_input, 'z%d_input_1' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm, training=training)
                z_input = layers.conv2D(z_input, 'z%d_input_2' % i,
                                        num_filters=num_channels[i],
                                        normalisation=norm, training=training)

                mu[i] = layers.conv2D(z_input, 'z%d_mu' % i,
                                      num_filters=zdim_0,
                                      activation=tf.identity,
                                      kernel_size=(1, 1))
                sigma[i] = layers.conv2D(z_input, 'z%d_sigma' % i,
                                         num_filters=zdim_0,
                                         activation=tf.nn.softplus,
                                         kernel_size=(1, 1))
                z[i] = mu[i] + sigma[i] * tf.random_normal(
                    tf.shape(mu[i]), 0, 1, dtype=tf.float32)

            # Seed the diagonal of z_ups_mat so the next (finer) iteration can
            # read z_ups_mat[j + 1][i + 1] above.
            z_ups_mat[i][i] = z[i]

    return z, mu, sigma
def unet2D_i2l(images, nlabels, training_pl, scope_reuse = False):
    """U-Net image-to-label mapper with batch-normalised conv blocks.

    Three down-sampling blocks (two BN convs + max-pool each), a bottleneck
    block, then three bilinear-upsample + skip-concat + two-conv blocks, and
    a final 1x1 prediction conv without normalisation or activation.

    Args:
        images: input image batch, NHWC.
        nlabels: number of output classes.
        training_pl: training placeholder/flag for batch norm.
        scope_reuse: reuse variables under 'i2l_mapper'.

    Returns:
        (pool1, pool2, pool3, conv4_2, conv5_2, conv6_2, conv7_2, pred):
        intermediate features of each scale plus the final logits.
    """
    n0 = 16
    n1, n2, n3, n4 = 1*n0, 2*n0, 4*n0, 8*n0

    with tf.variable_scope('i2l_mapper') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # ====================================
        # 1st Conv block - two conv layers, followed by max-pooling
        # ====================================
        conv1_1 = layers.conv2D_layer_bn(x=images, name='conv1_1', num_filters=n1, training = training_pl)
        conv1_2 = layers.conv2D_layer_bn(x=conv1_1, name='conv1_2', num_filters=n1, training = training_pl)
        pool1 = layers.max_pool_layer2d(conv1_2)

        # ====================================
        # 2nd Conv block
        # ====================================
        conv2_1 = layers.conv2D_layer_bn(x=pool1, name='conv2_1', num_filters=n2, training = training_pl)
        conv2_2 = layers.conv2D_layer_bn(x=conv2_1, name='conv2_2', num_filters=n2, training = training_pl)
        pool2 = layers.max_pool_layer2d(conv2_2)

        # ====================================
        # 3rd Conv block
        # ====================================
        conv3_1 = layers.conv2D_layer_bn(x=pool2, name='conv3_1', num_filters=n3, training = training_pl)
        conv3_2 = layers.conv2D_layer_bn(x=conv3_1, name='conv3_2', num_filters=n3, training = training_pl)
        # Fixed: previously pooled conv3_1, which dropped conv3_2's features
        # from the downstream path; blocks 1 and 2 pool their second conv, so
        # block 3 should too.
        pool3 = layers.max_pool_layer2d(conv3_2)

        # ====================================
        # 4th Conv block
        # ====================================
        conv4_1 = layers.conv2D_layer_bn(x=pool3, name='conv4_1', num_filters=n4, training = training_pl)
        conv4_2 = layers.conv2D_layer_bn(x=conv4_1, name='conv4_2', num_filters=n4, training = training_pl)

        # ====================================
        # Upsampling via bilinear upsampling, concatenation (skip connection), followed by 2 conv layers
        # ====================================
        deconv3 = layers.bilinear_upsample2D(conv4_2, size = (tf.shape(conv3_2)[1],tf.shape(conv3_2)[2]), name='upconv3')
        concat3 = tf.concat([deconv3, conv3_2], axis=-1)
        conv5_1 = layers.conv2D_layer_bn(x=concat3, name='conv5_1', num_filters=n3, training = training_pl)
        conv5_2 = layers.conv2D_layer_bn(x=conv5_1, name='conv5_2', num_filters=n3, training = training_pl)

        # ====================================
        # Upsampling via bilinear upsampling, concatenation (skip connection), followed by 2 conv layers
        # ====================================
        deconv2 = layers.bilinear_upsample2D(conv5_2, size = (tf.shape(conv2_2)[1],tf.shape(conv2_2)[2]), name='upconv2')
        concat2 = tf.concat([deconv2, conv2_2], axis=-1)
        conv6_1 = layers.conv2D_layer_bn(x=concat2, name='conv6_1', num_filters=n2, training = training_pl)
        conv6_2 = layers.conv2D_layer_bn(x=conv6_1, name='conv6_2', num_filters=n2, training = training_pl)

        # ====================================
        # Upsampling via bilinear upsampling, concatenation (skip connection), followed by 2 conv layers
        # ====================================
        deconv1 = layers.bilinear_upsample2D(conv6_2, size = (tf.shape(conv1_2)[1],tf.shape(conv1_2)[2]), name='upconv1')
        concat1 = tf.concat([deconv1, conv1_2], axis=-1)
        conv7_1 = layers.conv2D_layer_bn(x=concat1, name='conv7_1', num_filters=n1, training = training_pl)
        conv7_2 = layers.conv2D_layer_bn(x=conv7_1, name='conv7_2', num_filters=n1, training = training_pl)

        # ====================================
        # Final conv layer - without batch normalization or activation
        # ====================================
        pred = layers.conv2D_layer(x=conv7_2, name='pred', num_filters=nlabels, kernel_size=1)

    return pool1, pool2, pool3, conv4_2, conv5_2, conv6_2, conv7_2, pred