def unet3D(x, training, nlabels, n0=32, resolution_levels=4, norm=tfnorm.batch_norm, scope_reuse=False, return_net=False, **kwargs):
    """Build a 3D U-Net segmentation graph under the 'segmenter' variable scope.

    Args:
        x: Input tensor (5-D; assumes NDHWC layout — TODO confirm against `layers.conv3D`).
        training: Boolean (or placeholder) passed through to the layer helpers.
        nlabels: Number of output classes for the final 1x1x1 prediction conv.
        n0: Base number of filters; level ``i`` uses ``n0 * 2**i``.
        resolution_levels: Number of encoder levels (decoder has one fewer).
        norm: Normalisation function applied inside each conv layer.
        scope_reuse: If True, reuse existing variables in the 'segmenter' scope.
        return_net: If True, also return the list of all intermediate tensors.

    Returns:
        The prediction tensor, or ``(prediction, net)`` when ``return_net`` is True.
    """
    with tf.variable_scope('segmenter') as scope:
        if scope_reuse:
            scope.reuse_variables()

        # Batch norm supplies its own shift term, so the conv bias is redundant.
        add_bias = norm != tfnorm.batch_norm

        enc = []
        with tf.variable_scope('encoder'):
            for level in range(resolution_levels):
                filters = n0 * (2 ** level)
                # Level 0 consumes the raw input; deeper levels pool the previous output.
                stage = [x if level == 0 else layers.maxpool3D(enc[level - 1][-1])]
                stage.append(
                    layers.conv3D(stage[-1], 'conv_%d_1' % level,
                                  num_filters=filters, training=training,
                                  normalisation=norm, add_bias=add_bias))
                stage.append(
                    layers.conv3D(stage[-1], 'conv_%d_2' % level,
                                  num_filters=filters, training=training,
                                  normalisation=norm, add_bias=add_bias))
                enc.append(stage)

        dec = []
        with tf.variable_scope('decoder'):
            for up in range(resolution_levels - 1):
                level = resolution_levels - up - 1  # encoder level being mirrored
                source = enc[level][-1] if up == 0 else dec[up - 1][-1]
                filters = n0 * (2 ** level)
                # NOTE(review): the upconv uses num_filters=nlabels rather than
                # the level's filter count — looks unusual for a U-Net; confirm intended.
                stage = [
                    layers.transposed_conv3D(source, name='upconv_%d' % up,
                                             num_filters=nlabels, training=training,
                                             normalisation=norm, add_bias=add_bias)
                ]
                # Skip connection from the matching encoder level.
                stage.append(layers.crop_and_concat([stage[-1], enc[level - 1][-1]], axis=4))
                stage.append(
                    layers.conv3D(stage[-1], 'conv_%d_1' % up,
                                  num_filters=filters, training=training,
                                  normalisation=norm, add_bias=add_bias))
                stage.append(
                    layers.conv3D(stage[-1], 'conv_%d_2' % up,
                                  num_filters=filters, training=training,
                                  normalisation=norm, add_bias=add_bias))
                dec.append(stage)

        output = layers.conv3D(dec[-1][-1], 'prediction', num_filters=nlabels,
                               kernel_size=(1, 1, 1), activation=tf.identity,
                               training=training, normalisation=norm, add_bias=add_bias)
        dec[-1].append(output)

        if return_net:
            return output, enc + dec
        return output
def prob_unet2D(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, **kwargs):
    """Probabilistic U-Net likelihood network (2D).

    Runs a 2D U-Net over ``kwargs['x']``, broadcasts the latent vector
    ``z_list[0]`` over the spatial grid, concatenates it onto the final
    feature map, and predicts per-pixel class logits.

    Args:
        z_list: List of latent tensors; only ``z_list[0]`` is used, shape
            ``(batch, zdim)`` — TODO confirm against caller.
        training: Boolean (or placeholder) forwarded to the layer helpers.
        image_size: Spatial size; ``image_size[0]``/``[1]`` are used as H/W.
        n_classes: Number of output classes for the prediction conv.
        scope_reuse: If True, reuse variables in the 'likelihood' scope.
        norm: Normalisation function applied inside conv layers.
        **kwargs: ``x`` (input image, required), ``resolution_levels``
            (default 7), ``n0`` (base filters, default 32).

    Returns:
        A single-element list containing the logits tensor.
    """
    x = kwargs.get('x')
    z = z_list[0]
    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)

    # Channel schedule: double per level, capped at 6*n0. Generalised from the
    # previous fixed 7-entry list so resolution_levels > 7 no longer raises an
    # IndexError; values are identical for levels 0..6.
    num_channels = [n0 * min(2 ** i, 6) for i in range(resolution_levels)]

    conv_unit = layers.conv2D
    deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    bs = tf.shape(x)[0]
    zdim = z.get_shape().as_list()[-1]

    with tf.variable_scope('likelihood') as scope:
        if scope_reuse:
            scope.reuse_variables()

        # Batch norm supplies its own shift term, so the conv bias is redundant.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []
        with tf.variable_scope('encoder'):
            for ii in range(resolution_levels):
                enc.append([])
                # In first layer set input to x rather than average pooling.
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))
                for cc in (1, 2, 3):
                    enc[ii].append(
                        conv_unit(enc[ii][-1], 'conv_%d_%d' % (ii, cc),
                                  num_filters=num_channels[ii], training=training,
                                  normalisation=norm, add_bias=add_bias))

        dec = []
        with tf.variable_scope('decoder'):
            for jj in range(resolution_levels - 1):
                ii = resolution_levels - jj - 1  # used to index the encoder again
                dec.append([])
                next_inp = enc[ii][-1] if jj == 0 else dec[jj - 1][-1]
                dec[jj].append(deconv_unit(next_inp))
                # Skip connection from the matching encoder level.
                dec[jj].append(layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]], axis=3))
                for cc in (1, 2, 3):
                    dec[jj].append(
                        conv_unit(dec[jj][-1], 'conv_%d_%d' % (jj, cc),
                                  num_filters=num_channels[ii], training=training,
                                  normalisation=norm, add_bias=add_bias))

        # Broadcast z over the spatial grid and fuse with the decoder output.
        z_t = tf.reshape(z, tf.stack((bs, 1, 1, zdim)))
        broadcast_z = tf.tile(z_t, (1, image_size[0], image_size[1], 1))
        net = tf.concat([dec[-1][-1], broadcast_z], axis=-1)

        # Three 1x1 recombination convs mix image features with the latent code.
        recomb = net
        for rr in range(3):
            recomb = conv_unit(recomb, 'recomb_%d' % rr,
                               num_filters=num_channels[0], kernel_size=(1, 1),
                               training=training, normalisation=norm, add_bias=add_bias)

        s = [
            layers.conv2D(recomb, 'prediction', num_filters=n_classes,
                          kernel_size=(1, 1), activation=tf.identity)
        ]
        return s
def proposed(z_list, training, image_size, n_classes, scope_reuse=False, norm=tfnorm.batch_norm, rank=10, diagonal=False, **kwargs):
    """Likelihood network producing a (low-rank or diagonal) Gaussian over logits.

    Runs a 2D U-Net over ``kwargs['x']``, then predicts the mean, log-diagonal
    covariance, and low-rank covariance factor of a multivariate normal over
    the flattened per-pixel logits, and draws one reparameterised sample.
    ``z_list`` is accepted for interface compatibility but unused here.

    Args:
        z_list: Unused (kept for a uniform likelihood-network signature).
        training: Boolean (or placeholder) forwarded to the layer helpers.
        image_size: Tuple whose last entry is the channel count; spatial dims
            come from ``image_size[:-1]`` — TODO confirm (H, W, C) convention.
        n_classes: Number of output classes.
        scope_reuse: If True, reuse variables in the 'likelihood' scope.
        norm: Normalisation function applied inside conv layers.
        rank: Rank of the low-rank covariance factor.
        diagonal: If True, use a diagonal covariance instead of low-rank.
        **kwargs: ``x`` (input image, required), ``resolution_levels``
            (default 7), ``n0`` (base filters, default 32).

    Returns:
        ``[[sample], dist]`` where ``sample`` has shape
        ``(batch,) + image_size[:-1] + (n_classes,)``.
    """
    x = kwargs.get('x')
    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)

    # Channel schedule: double per level, capped at 6*n0. Generalised from the
    # previous fixed 7-entry list so resolution_levels > 7 no longer raises an
    # IndexError; values are identical for levels 0..6.
    num_channels = [n0 * min(2 ** i, 6) for i in range(resolution_levels)]

    conv_unit = layers.conv2D
    deconv_unit = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    with tf.variable_scope('likelihood') as scope:
        if scope_reuse:
            scope.reuse_variables()

        # Batch norm supplies its own shift term, so the conv bias is redundant.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []
        with tf.variable_scope('encoder'):
            for ii in range(resolution_levels):
                enc.append([])
                # In first layer set input to x rather than average pooling.
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))
                for cc in (1, 2, 3):
                    enc[ii].append(
                        conv_unit(enc[ii][-1], 'conv_%d_%d' % (ii, cc),
                                  num_filters=num_channels[ii], training=training,
                                  normalisation=norm, add_bias=add_bias))

        dec = []
        with tf.variable_scope('decoder'):
            for jj in range(resolution_levels - 1):
                ii = resolution_levels - jj - 1  # used to index the encoder again
                dec.append([])
                next_inp = enc[ii][-1] if jj == 0 else dec[jj - 1][-1]
                dec[jj].append(deconv_unit(next_inp))
                # Skip connection from the matching encoder level.
                dec[jj].append(layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]], axis=3))
                for cc in (1, 2, 3):
                    dec[jj].append(
                        conv_unit(dec[jj][-1], 'conv_%d_%d' % (jj, cc),
                                  num_filters=num_channels[ii], training=training,
                                  normalisation=norm, add_bias=add_bias))

        # Three 1x1 recombination convs before the distribution heads.
        recomb = dec[-1][-1]
        for rr in range(3):
            recomb = conv_unit(recomb, 'recomb_%d' % rr,
                               num_filters=num_channels[0], kernel_size=(1, 1),
                               training=training, normalisation=norm, add_bias=add_bias)

        epsilon = 1e-5  # floor on the diagonal covariance for numerical stability

        mean = layers.conv2D(recomb, 'mean', num_filters=n_classes,
                             kernel_size=(1, 1), activation=tf.identity)
        log_cov_diag = layers.conv2D(recomb, 'diag', num_filters=n_classes,
                                     kernel_size=(1, 1), activation=tf.identity)
        cov_factor = layers.conv2D(recomb, 'factor', num_filters=n_classes * rank,
                                   kernel_size=(1, 1), activation=tf.identity)

        # Flatten spatial dims and classes into one event dimension.
        shape = image_size[:-1] + (n_classes, )
        flat_size = np.prod(shape)

        mean = tf.reshape(mean, [-1, flat_size])
        cov_diag = tf.exp(tf.reshape(log_cov_diag, [-1, flat_size])) + epsilon
        cov_factor = tf.reshape(cov_factor, [-1, flat_size, rank])

        if diagonal:
            dist = DiagonalMultivariateNormal(loc=mean, cov_diag=cov_diag)
        else:
            dist = LowRankMultivariateNormal(loc=mean, cov_diag=cov_diag,
                                             cov_factor=cov_factor)

        # One reparameterised sample, restored to image layout.
        s = dist.rsample((1, ))
        s = tf.reshape(s, (-1, ) + shape)

        return [[s], dist]
def prob_unet2D_arch(
        x, training, nlabels, n0=32, resolution_levels=7, norm=tfnorm.batch_norm,
        conv_unit=layers.conv2D,
        deconv_unit=lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2),
        scope_reuse=False, return_net=False):
    """Deterministic 2D U-Net backbone used by the probabilistic U-Net.

    Args:
        x: Input tensor (4-D; assumes NHWC layout — TODO confirm against
            ``layers.conv2D``).
        training: Boolean (or placeholder) forwarded to the layer helpers.
        nlabels: Number of output classes for the final 1x1 prediction conv.
        n0: Base number of filters.
        resolution_levels: Number of encoder levels (decoder has one fewer).
        norm: Normalisation function applied inside conv layers.
        conv_unit: Convolution builder (default ``layers.conv2D``).
        deconv_unit: Upsampling builder (default bilinear x2).
        scope_reuse: If True, reuse variables in the 'likelihood' scope.
        return_net: Accepted for interface compatibility; currently unused.

    Returns:
        The logits tensor of the prediction layer.
    """
    # Channel schedule: double per level, capped at 6*n0. Generalised from the
    # previous fixed 7-entry list so resolution_levels > 7 no longer raises an
    # IndexError; values are identical for levels 0..6.
    num_channels = [n0 * min(2 ** i, 6) for i in range(resolution_levels)]

    with tf.variable_scope('likelihood') as scope:
        if scope_reuse:
            scope.reuse_variables()

        # Batch norm supplies its own shift term, so the conv bias is redundant.
        add_bias = False if norm == tfnorm.batch_norm else True

        enc = []
        with tf.variable_scope('encoder'):
            for ii in range(resolution_levels):
                enc.append([])
                # In first layer set input to x rather than average pooling.
                if ii == 0:
                    enc[ii].append(x)
                else:
                    enc[ii].append(layers.averagepool2D(enc[ii - 1][-1]))
                for cc in (1, 2, 3):
                    enc[ii].append(
                        conv_unit(enc[ii][-1], 'conv_%d_%d' % (ii, cc),
                                  num_filters=num_channels[ii], training=training,
                                  normalisation=norm, add_bias=add_bias))

        dec = []
        with tf.variable_scope('decoder'):
            for jj in range(resolution_levels - 1):
                ii = resolution_levels - jj - 1  # used to index the encoder again
                dec.append([])
                next_inp = enc[ii][-1] if jj == 0 else dec[jj - 1][-1]
                dec[jj].append(deconv_unit(next_inp))
                # Skip connection from the matching encoder level.
                dec[jj].append(layers.crop_and_concat([dec[jj][-1], enc[ii - 1][-1]], axis=3))
                for cc in (1, 2, 3):
                    dec[jj].append(
                        conv_unit(dec[jj][-1], 'conv_%d_%d' % (jj, cc),
                                  num_filters=num_channels[ii], training=training,
                                  normalisation=norm, add_bias=add_bias))

        # Three 1x1 recombination convs ahead of the prediction head.
        recomb = dec[-1][-1]
        for rr in range(3):
            recomb = conv_unit(recomb, 'recomb_%d' % rr,
                               num_filters=num_channels[0], kernel_size=(1, 1),
                               training=training, normalisation=norm, add_bias=add_bias)

        s = layers.conv2D(recomb, 'prediction', num_filters=nlabels,
                          kernel_size=(1, 1), activation=tf.identity)
        return s