Ejemplo n.º 1
0
def unet3D(x,
           training,
           nlabels,
           n0=32,
           resolution_levels=4,
           norm=tfnorm.batch_norm,
           scope_reuse=False,
           return_net=False,
           **kwargs):
    """Build a 3D U-Net segmentation graph under the 'segmenter' scope.

    Args:
        x: Input tensor (assumed 5-D: batch, depth, height, width,
            channels -- TODO confirm against callers).
        training: Train/eval switch forwarded to normalisation layers.
        nlabels: Number of output segmentation classes.
        n0: Filter count at the top level; doubled at each deeper level.
        resolution_levels: Number of encoder resolution levels.
        norm: Normalisation function applied inside the conv layers.
        scope_reuse: Reuse variables of the 'segmenter' scope when True.
        return_net: Also return the list of all intermediate tensors.
        **kwargs: Ignored; accepted for interface compatibility.

    Returns:
        The logits tensor, or ``(logits, net)`` when ``return_net``.
    """
    with tf.variable_scope('segmenter') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm has its own shift term, so a conv bias is redundant.
        use_bias = norm != tfnorm.batch_norm

        def _conv(inp, name, nf):
            # Shared 3x3x3 conv so every stage gets identical settings.
            return layers.conv3D(inp,
                                 name,
                                 num_filters=nf,
                                 training=training,
                                 normalisation=norm,
                                 add_bias=use_bias)

        enc = []

        with tf.variable_scope('encoder'):

            for lvl in range(resolution_levels):

                nf = n0 * (2 ** lvl)

                # Top level consumes the input directly; deeper levels
                # downsample the previous level's output.
                head = x if lvl == 0 else layers.maxpool3D(enc[lvl - 1][-1])

                stage = [head]
                stage.append(_conv(stage[-1], 'conv_%d_1' % lvl, nf))
                stage.append(_conv(stage[-1], 'conv_%d_2' % lvl, nf))
                enc.append(stage)

        dec = []

        with tf.variable_scope('decoder'):

            for up in range(resolution_levels - 1):

                # Encoder level that matches this upsampling step.
                lvl = resolution_levels - up - 1
                nf = n0 * (2 ** lvl)

                source = enc[lvl][-1] if up == 0 else dec[up - 1][-1]

                stage = [layers.transposed_conv3D(source,
                                                  name='upconv_%d' % up,
                                                  num_filters=nlabels,
                                                  training=training,
                                                  normalisation=norm,
                                                  add_bias=use_bias)]

                # Skip connection from the matching encoder output
                # (channel axis is 4 for 3D tensors).
                stage.append(
                    layers.crop_and_concat([stage[-1], enc[lvl - 1][-1]],
                                           axis=4))

                stage.append(_conv(stage[-1], 'conv_%d_1' % up, nf))
                stage.append(_conv(stage[-1], 'conv_%d_2' % up, nf))
                dec.append(stage)

        # 1x1x1 linear projection to per-class logits.
        output = layers.conv3D(dec[-1][-1],
                               'prediction',
                               num_filters=nlabels,
                               kernel_size=(1, 1, 1),
                               activation=tf.identity,
                               training=training,
                               normalisation=norm,
                               add_bias=use_bias)

        dec[-1].append(output)

        if return_net:
            return output, enc + dec

        return output
Ejemplo n.º 2
0
def prob_unet2D(z_list,
                training,
                image_size,
                n_classes,
                scope_reuse=False,
                norm=tfnorm.batch_norm,
                **kwargs):
    """Likelihood network of a 2D probabilistic U-Net.

    Runs a U-Net backbone on the input image, broadcasts the latent code
    ``z_list[0]`` over the spatial grid, concatenates it onto the decoder
    output, mixes with three 1x1 'recomb' convs and projects to logits.

    Args:
        z_list: List of latent tensors; only the first entry is used.
        training: Train/eval switch forwarded to the conv layers.
        image_size: Spatial size; ``image_size[0]``/``[1]`` are assumed to
            be height and width -- TODO confirm against callers.
        n_classes: Number of output classes (logit channels).
        scope_reuse: Reuse variables of the 'likelihood' scope when True.
        norm: Normalisation function used inside the conv layers.
        **kwargs: ``x`` (input image tensor, required), optional
            ``resolution_levels`` (default 7) and ``n0`` (default 32).

    Returns:
        A single-element list containing the logits tensor.
    """
    x = kwargs.get('x')

    z = z_list[0]

    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)
    channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    conv_unit = layers.conv2D
    upsample = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    batch = tf.shape(x)[0]
    latent_dim = z.get_shape().as_list()[-1]

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm has its own shift term, so a conv bias is redundant.
        use_bias = norm != tfnorm.batch_norm

        def _conv(inp, name, nf, ks=None):
            # Shared conv helper; ks overrides the default kernel size.
            kw = dict(num_filters=nf,
                      training=training,
                      normalisation=norm,
                      add_bias=use_bias)
            if ks is not None:
                kw['kernel_size'] = ks
            return conv_unit(inp, name, **kw)

        enc = []

        with tf.variable_scope('encoder'):

            for lvl in range(resolution_levels):

                # Top level consumes the input directly; deeper levels
                # average-pool the previous level's output.
                head = x if lvl == 0 else layers.averagepool2D(enc[lvl - 1][-1])

                stage = [head]
                for k in (1, 2, 3):
                    stage.append(
                        _conv(stage[-1], 'conv_%d_%d' % (lvl, k),
                              channels[lvl]))
                enc.append(stage)

        dec = []

        with tf.variable_scope('decoder'):

            for up in range(resolution_levels - 1):

                # Encoder level that matches this upsampling step.
                lvl = resolution_levels - up - 1

                source = enc[lvl][-1] if up == 0 else dec[up - 1][-1]

                stage = [upsample(source)]

                # Skip connection (channel axis is 3 for 2D tensors).
                stage.append(
                    layers.crop_and_concat([stage[-1], enc[lvl - 1][-1]],
                                           axis=3))

                for k in (1, 2, 3):
                    stage.append(
                        _conv(stage[-1], 'conv_%d_%d' % (up, k),
                              channels[lvl]))
                dec.append(stage)

        # Broadcast the latent code over the spatial grid and fuse it
        # with the decoder output along the channel axis.
        z_map = tf.reshape(z, tf.stack((batch, 1, 1, latent_dim)))
        z_map = tf.tile(z_map, (1, image_size[0], image_size[1], 1))
        fused = tf.concat([dec[-1][-1], z_map], axis=-1)

        # Three 1x1 convs re-mix features after the latent injection.
        recomb = fused
        for r in range(3):
            recomb = _conv(recomb, 'recomb_%d' % r, channels[0], ks=(1, 1))

        # Linear 1x1 projection to logits (no normalisation on the head).
        logits = layers.conv2D(recomb,
                               'prediction',
                               num_filters=n_classes,
                               kernel_size=(1, 1),
                               activation=tf.identity)

        return [logits]
Ejemplo n.º 3
0
def proposed(z_list,
             training,
             image_size,
             n_classes,
             scope_reuse=False,
             norm=tfnorm.batch_norm,
             rank=10,
             diagonal=False,
             **kwargs):
    """Likelihood network emitting a structured Gaussian over logit maps.

    A 2D U-Net backbone produces per-pixel features; three 1x1 'recomb'
    convs mix them, and three 1x1 heads parameterise either a diagonal or
    a low-rank-plus-diagonal multivariate normal over flattened logits.

    Args:
        z_list: Present for interface compatibility with sibling
            likelihood networks; NOTE(review): never read here.
        training: Train/eval switch forwarded to the conv layers.
        image_size: Spatial input size; ``image_size[:-1]`` is assumed to
            be the (height, width) tuple -- TODO confirm against callers.
        n_classes: Number of output classes (logit channels).
        scope_reuse: Reuse variables of the 'likelihood' scope when True.
        norm: Normalisation function used inside the conv layers.
        rank: Rank of the low-rank covariance factor.
        diagonal: If True, use a purely diagonal covariance.
        **kwargs: ``x`` (input image tensor, required), optional
            ``resolution_levels`` (default 7) and ``n0`` (default 32).

    Returns:
        ``[[s], dist]`` where ``s`` is one reparameterised logit sample
        and ``dist`` is the underlying distribution object.
    """
    x = kwargs.get('x')

    resolution_levels = kwargs.get('resolution_levels', 7)
    n0 = kwargs.get('n0', 32)
    channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    conv_unit = layers.conv2D
    upsample = lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2)

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm has its own shift term, so a conv bias is redundant.
        use_bias = norm != tfnorm.batch_norm

        def _conv(inp, name, nf, ks=None):
            # Shared conv helper; ks overrides the default kernel size.
            kw = dict(num_filters=nf,
                      training=training,
                      normalisation=norm,
                      add_bias=use_bias)
            if ks is not None:
                kw['kernel_size'] = ks
            return conv_unit(inp, name, **kw)

        enc = []

        with tf.variable_scope('encoder'):

            for lvl in range(resolution_levels):

                # Top level consumes the input directly; deeper levels
                # average-pool the previous level's output.
                head = x if lvl == 0 else layers.averagepool2D(enc[lvl - 1][-1])

                stage = [head]
                for k in (1, 2, 3):
                    stage.append(
                        _conv(stage[-1], 'conv_%d_%d' % (lvl, k),
                              channels[lvl]))
                enc.append(stage)

        dec = []

        with tf.variable_scope('decoder'):

            for up in range(resolution_levels - 1):

                # Encoder level that matches this upsampling step.
                lvl = resolution_levels - up - 1

                source = enc[lvl][-1] if up == 0 else dec[up - 1][-1]

                stage = [upsample(source)]

                # Skip connection (channel axis is 3 for 2D tensors).
                stage.append(
                    layers.crop_and_concat([stage[-1], enc[lvl - 1][-1]],
                                           axis=3))

                for k in (1, 2, 3):
                    stage.append(
                        _conv(stage[-1], 'conv_%d_%d' % (up, k),
                              channels[lvl]))
                dec.append(stage)

        # Three 1x1 convs re-mix the decoder features before the heads.
        recomb = dec[-1][-1]
        for r in range(3):
            recomb = _conv(recomb, 'recomb_%d' % r, channels[0], ks=(1, 1))

        # Jitter added to the diagonal to keep the covariance positive.
        epsilon = 1e-5

        # Linear 1x1 heads for the Gaussian parameters (no normalisation).
        mean = layers.conv2D(recomb,
                             'mean',
                             num_filters=n_classes,
                             kernel_size=(1, 1),
                             activation=tf.identity)
        log_cov_diag = layers.conv2D(recomb,
                                     'diag',
                                     num_filters=n_classes,
                                     kernel_size=(1, 1),
                                     activation=tf.identity)
        cov_factor = layers.conv2D(recomb,
                                   'factor',
                                   num_filters=n_classes * rank,
                                   kernel_size=(1, 1),
                                   activation=tf.identity)

        # Flatten (H, W, n_classes) so the distribution acts on vectors.
        shape = image_size[:-1] + (n_classes, )
        flat_size = np.prod(shape)
        mean = tf.reshape(mean, [-1, flat_size])
        cov_diag = tf.exp(tf.reshape(log_cov_diag, [-1, flat_size])) + epsilon
        cov_factor = tf.reshape(cov_factor, [-1, flat_size, rank])

        if diagonal:
            dist = DiagonalMultivariateNormal(loc=mean, cov_diag=cov_diag)
        else:
            dist = LowRankMultivariateNormal(loc=mean,
                                             cov_diag=cov_diag,
                                             cov_factor=cov_factor)

        # One reparameterised sample, restored to image shape.
        s = dist.rsample((1, ))
        s = tf.reshape(s, (-1, ) + shape)
        return [[s], dist]
Ejemplo n.º 4
0
def prob_unet2D_arch(
        x,
        training,
        nlabels,
        n0=32,
        resolution_levels=7,
        norm=tfnorm.batch_norm,
        conv_unit=layers.conv2D,
        deconv_unit=lambda inp: layers.bilinear_upsample2D(inp, 'upsample', 2),
        scope_reuse=False,
        return_net=False):
    """2D U-Net backbone of the probabilistic U-Net (deterministic head).

    Args:
        x: Input image tensor.
        training: Train/eval switch forwarded to the conv layers.
        nlabels: Number of output classes (logit channels).
        n0: Base filter count; the per-level widths are derived from it.
        resolution_levels: Number of encoder resolution levels.
        norm: Normalisation function used inside the conv layers.
        conv_unit: Convolution layer constructor to use.
        deconv_unit: Upsampling function used in the decoder.
        scope_reuse: Reuse variables of the 'likelihood' scope when True.
        return_net: Accepted for interface compatibility;
            NOTE(review): never read here.

    Returns:
        The logits tensor.
    """
    channels = [n0, 2 * n0, 4 * n0, 6 * n0, 6 * n0, 6 * n0, 6 * n0]

    with tf.variable_scope('likelihood') as scope:

        if scope_reuse:
            scope.reuse_variables()

        # Batch norm has its own shift term, so a conv bias is redundant.
        use_bias = norm != tfnorm.batch_norm

        def _conv(inp, name, nf, ks=None):
            # Shared conv helper; ks overrides the default kernel size.
            kw = dict(num_filters=nf,
                      training=training,
                      normalisation=norm,
                      add_bias=use_bias)
            if ks is not None:
                kw['kernel_size'] = ks
            return conv_unit(inp, name, **kw)

        enc = []

        with tf.variable_scope('encoder'):

            for lvl in range(resolution_levels):

                # Top level consumes the input directly; deeper levels
                # average-pool the previous level's output.
                head = x if lvl == 0 else layers.averagepool2D(enc[lvl - 1][-1])

                stage = [head]
                for k in (1, 2, 3):
                    stage.append(
                        _conv(stage[-1], 'conv_%d_%d' % (lvl, k),
                              channels[lvl]))
                enc.append(stage)

        dec = []

        with tf.variable_scope('decoder'):

            for up in range(resolution_levels - 1):

                # Encoder level that matches this upsampling step.
                lvl = resolution_levels - up - 1

                source = enc[lvl][-1] if up == 0 else dec[up - 1][-1]

                stage = [deconv_unit(source)]

                # Skip connection (channel axis is 3 for 2D tensors).
                stage.append(
                    layers.crop_and_concat([stage[-1], enc[lvl - 1][-1]],
                                           axis=3))

                for k in (1, 2, 3):
                    stage.append(
                        _conv(stage[-1], 'conv_%d_%d' % (up, k),
                              channels[lvl]))
                dec.append(stage)

        # Three 1x1 convs re-mix the decoder features before the head.
        recomb = dec[-1][-1]
        for r in range(3):
            recomb = _conv(recomb, 'recomb_%d' % r, channels[0], ks=(1, 1))

        # Linear 1x1 projection to logits (no normalisation on the head).
        s = layers.conv2D(recomb,
                          'prediction',
                          num_filters=nlabels,
                          kernel_size=(1, 1),
                          activation=tf.identity)

        return s