Exemplo n.º 1
0
def network(input, is_training, num_outputs_embedding, actual_network, filters=64, levels=7, activation='relu', normalize=False, data_format='channels_first', padding='same'):
    """
    Two-stage embedding network.

    Stage 'unet_0' runs a U-Net on the input and maps its output to embeddings with a 1x1x1 convolution.
    Stage 'unet_1' concatenates the input with the L2-normalized stage-0 embeddings and applies another
    U-Net + 1x1x1 convolution to that.
    :param input: Input tensor.
    :param is_training: True, if training network.
    :param num_outputs_embedding: Number of embedding channels produced by each stage.
    :param actual_network: U-Net class/factory. It is instantiated with num_filters_base, kernel, num_levels,
                           data_format, kernel_initializer, activation, is_training, name and padding; the
                           resulting object is called as unet(tensor, is_training).
    :param filters: Base number of filters of the U-Nets.
    :param levels: Number of levels of the U-Nets.
    :param activation: 'selu', 'relu' or 'tanh'.
    :param normalize: If True, the embedding outputs are L2-normalized along the channel axis.
    :param data_format: 'channels_first' or 'channels_last'.
    :param padding: Padding parameter passed to the U-Nets and the convolution operations.
    :return: (embeddings, embeddings_2), the outputs of stage 'unet_0' and stage 'unet_1'.
    :raises ValueError: If activation is not 'selu', 'relu' or 'tanh'.
    """
    if activation == 'selu':
        activation = tf.nn.selu
        kernel_initializer = selu_initializer
    elif activation == 'relu':
        activation = tf.nn.relu
        kernel_initializer = he_initializer
    elif activation == 'tanh':
        activation = tf.nn.tanh
        # NOTE: tanh reuses the selu initializer (kept as in the original implementation).
        kernel_initializer = selu_initializer
    else:
        # Fail early: previously an unknown value left kernel_initializer unbound and the
        # function crashed later with an UnboundLocalError.
        raise ValueError("activation must be 'selu', 'relu' or 'tanh', got {!r}".format(activation))

    # Channel axis: 1 for channels_first, 4 (last axis) for channels_last; assumes rank-5 volume tensors.
    embedding_axis = 1 if data_format == 'channels_first' else 4

    def l2_normalize(x, name):
        # L2-normalize along the channel axis; epsilon guards against division by ~0.
        return tf.nn.l2_normalize(x, dim=embedding_axis, name=name, epsilon=1e-4)

    if normalize:
        embeddings_activation = l2_normalize
    elif activation == tf.nn.selu:
        embeddings_activation = tf.nn.selu
    else:
        # relu / tanh: the embedding convolution stays linear.
        embeddings_activation = None

    with tf.variable_scope('unet_0'):
        unet = actual_network(num_filters_base=filters, kernel=[3, 3, 3], num_levels=levels, data_format=data_format, kernel_initializer=kernel_initializer, activation=activation, is_training=is_training, name='unet', padding=padding)
        unet_out = unet(input, is_training)
        embeddings = conv3d(unet_out, kernel_size=[1, 1, 1], name='embeddings', filters=num_outputs_embedding, kernel_initializer=kernel_initializer, activation=embeddings_activation, data_format=data_format, is_training=is_training, padding=padding)
    with tf.variable_scope('unet_1'):
        # The second stage always sees normalized stage-0 embeddings, independent of `normalize`.
        normalized_embeddings = l2_normalize(embeddings, 'embeddings_normalized')
        input_concat = concat_channels([input, normalized_embeddings], name='input_concat', data_format=data_format)
        unet = actual_network(num_filters_base=filters, kernel=[3, 3, 3], num_levels=levels, data_format=data_format, kernel_initializer=kernel_initializer, activation=activation, is_training=is_training, name='unet', padding=padding)
        unet_out = unet(input_concat, is_training)
        embeddings_2 = conv3d(unet_out, kernel_size=[1, 1, 1], name='embeddings', filters=num_outputs_embedding, kernel_initializer=kernel_initializer, activation=embeddings_activation, data_format=data_format, is_training=is_training, padding=padding)
    return embeddings, embeddings_2
Exemplo n.º 2
0
def network_unet(input,
                 num_heatmaps,
                 is_training,
                 data_format='channels_first'):
    """
    Plain U-Net heatmap regressor.

    A 3x3x3 ReLU convolution lifts the input to 64 feature maps, a 5-level UnetClassic3D
    (constant filter count per level, no normalization, ReLU) processes them, and a final
    linear 3x3x3 convolution emits one channel per heatmap.
    :param input: Input tensor.
    :param num_heatmaps: Number of output heatmaps.
    :param is_training: True, if training network.
    :param data_format: 'channels_first' or 'channels_last'.
    :return: The heatmap tensor three times (presumably to match the
             (heatmaps, local_heatmaps, spatial_heatmaps) triple of network_scn).
    """
    base_filters = 64
    relu = tf.nn.relu

    features = conv3d(input,
                      filters=base_filters,
                      kernel_size=[3, 3, 3],
                      name='conv0',
                      activation=relu,
                      data_format=data_format,
                      is_training=is_training)

    unet = UnetClassic3D(num_filters_base=base_filters,
                         num_levels=5,
                         double_filters_per_level=False,
                         normalization=None,
                         activation=relu,
                         data_format=data_format)
    unet_features = unet(features, is_training)

    # Very small initial weights (stddev 1e-4) for the linear heatmap layer.
    tiny_init = tf.truncated_normal_initializer(stddev=0.0001)
    heatmaps = conv3d(unet_features,
                      filters=num_heatmaps,
                      kernel_size=[3, 3, 3],
                      name='heatmaps',
                      kernel_initializer=tiny_init,
                      activation=None,
                      data_format=data_format,
                      is_training=is_training)

    return heatmaps, heatmaps, heatmaps
Exemplo n.º 3
0
 def conv(self, node, current_level, postfix, is_training):
     """
     Build one 3x3x3 convolution for the given level.
     :param node: Input tensor.
     :param current_level: Level index; the filter count is self.num_filters(current_level).
     :param postfix: Suffix appended to 'conv' to form the layer name.
     :param is_training: True, if training network.
     :return: Output tensor of conv3d (uses self.activation, self.normalization,
              self.data_format and self.padding).
     """
     filter_count = self.num_filters(current_level)
     layer_name = 'conv' + postfix
     return conv3d(node,
                   filter_count,
                   [3, 3, 3],
                   name=layer_name,
                   activation=self.activation,
                   normalization=self.normalization,
                   is_training=is_training,
                   data_format=self.data_format,
                   padding=self.padding)
Exemplo n.º 4
0
 def conv(self, node, current_level, postfix, is_training):
     """
     Build one convolution for the given level using the configured kernel.
     :param node: Input tensor.
     :param current_level: Level index; the filter count is self.num_filters(current_level).
     :param postfix: Suffix appended to 'conv' to form the layer name.
     :param is_training: True, if training network.
     :return: Output tensor of conv3d (kernel size self.kernel_size, no normalization,
              self.activation, self.kernel_initializer, self.data_format, self.padding).
     """
     filter_count = self.num_filters(current_level)
     kernel = self.kernel_size
     return conv3d(node,
                   filter_count,
                   kernel,
                   name='conv' + postfix,
                   activation=self.activation,
                   normalization=None,
                   is_training=is_training,
                   data_format=self.data_format,
                   kernel_initializer=self.kernel_initializer,
                   padding=self.padding)
Exemplo n.º 5
0
def network_u(input,
              is_training,
              num_labels,
              data_format='channels_first',
              activation='relu',
              padding='same',
              actual_network=None,
              *args,
              **kwargs):
    """
    The U-Net
    :param input: Input tensor.
    :param is_training: True, if training network.
    :param num_labels: Number of outputs.
    :param data_format: 'channels_first' or 'channels_last'
    :param activation: The activation function. 'relu' or 'selu'
    :param padding: Padding parameter passed to the convolution operations.
    :param actual_network: The u-net class/factory used as the local appearance network. Required.
    :param args: Not used.
    :param kwargs: Passed to actual_network()
    :return: prediction
    :raises ValueError: If actual_network is None or activation is neither 'relu' nor 'selu'.
    """
    # Validate up front; previously a missing network crashed with "'NoneType' object is not callable".
    if actual_network is None:
        raise ValueError('actual_network must be given, got None')
    if activation == 'relu':
        kernel_initializer = he_initializer
        activation = tf.nn.relu
    elif activation == 'selu':
        kernel_initializer = selu_initializer
        activation = tf.nn.selu
    else:
        # Previously any other value silently fell back to selu, hiding typos.
        raise ValueError("activation must be 'relu' or 'selu', got {!r}".format(activation))
    # The output layer starts from small weights and stays linear.
    local_kernel_initializer = tf.initializers.truncated_normal(stddev=0.001)
    local_activation = None
    with tf.variable_scope('local'):
        unet = actual_network(data_format=data_format,
                              kernel_initializer=kernel_initializer,
                              activation=activation,
                              padding=padding,
                              **kwargs)
        prediction = unet(input, is_training=is_training)
        prediction = conv3d(prediction,
                            num_labels, [3, 3, 3],
                            name='output',
                            padding=padding,
                            kernel_initializer=local_kernel_initializer,
                            activation=local_activation,
                            is_training=is_training,
                            data_format=data_format)
    return prediction
def network_unet(input, num_labels, is_training, data_format='channels_first'):
    """
    U-Net (UnetClassicAvgLinear3d, 64 base filters, 4 levels, doubling filters per level)
    followed by a linear 1x1x1 output convolution.
    :param input: Input tensor.
    :param num_labels: Number of outputs.
    :param is_training: True, if training network.
    :param data_format: 'channels_first' or 'channels_last'. Forwarded to every layer.
    :return: The prediction tensor three times (presumably to match the
             (prediction, local_prediction, spatial_prediction) triple of network_scn).
    """
    kernel_initializer = he_initializer
    activation = tf.nn.relu
    local_kernel_initializer = he_initializer
    local_activation = None
    padding = 'reflect'
    with tf.variable_scope('unet'):
        unet = UnetClassicAvgLinear3d(64,
                                      4,
                                      data_format=data_format,
                                      double_filters_per_level=True,
                                      kernel_initializer=kernel_initializer,
                                      activation=activation,
                                      padding=padding)
        prediction = unet(input, is_training=is_training)
        # data_format is forwarded here as well; previously the output convolution used
        # conv3d's default data_format even when the U-Net ran with 'channels_last'.
        prediction = conv3d(prediction,
                            num_labels, [1, 1, 1],
                            name='output',
                            padding=padding,
                            kernel_initializer=local_kernel_initializer,
                            activation=local_activation,
                            is_training=is_training,
                            data_format=data_format)
    return prediction, prediction, prediction
Exemplo n.º 7
0
def network_scn(input,
                num_heatmaps,
                is_training,
                data_format='channels_first'):
    """
    Spatial configuration net for heatmap regression.

    Local appearance: a 3x3x3 convolution plus a 4-level SCNetLocal produce local heatmaps.
    Spatial configuration: the local heatmaps are average-pooled by 8, passed through
    three 7x7x7 convolutions and a tanh 7x7x7 convolution, then cubically upsampled by 8.
    The final heatmaps are the element-wise product of both stages.
    :param input: Input tensor.
    :param num_heatmaps: Number of output heatmaps.
    :param is_training: True, if training network.
    :param data_format: 'channels_first' or 'channels_last'.
    :return: (heatmaps, local_heatmaps, spatial_heatmaps)
    """
    base_filters = 64
    factor = 8
    pad_mode = 'reflect'
    leaky_relu = lambda x, name: tf.nn.leaky_relu(x, name=name, alpha=0.1)
    heatmap_init = tf.truncated_normal_initializer(stddev=0.001)

    # --- local appearance stage ---
    features = conv3d(input,
                      filters=base_filters,
                      kernel_size=[3, 3, 3],
                      name='conv0',
                      activation=leaky_relu,
                      data_format=data_format,
                      is_training=is_training)
    local_net = SCNetLocal(num_filters_base=base_filters,
                           num_levels=4,
                           double_filters_per_level=False,
                           normalization=None,
                           activation=leaky_relu,
                           data_format=data_format,
                           padding=pad_mode)
    local_features = local_net(features, is_training)
    local_heatmaps = conv3d(local_features,
                            filters=num_heatmaps,
                            kernel_size=[3, 3, 3],
                            name='local_heatmaps',
                            kernel_initializer=heatmap_init,
                            activation=None,
                            data_format=data_format,
                            is_training=is_training)

    # --- spatial configuration stage (low resolution) ---
    spatial = avg_pool3d(local_heatmaps, [factor] * 3,
                         name='local_downsampled',
                         data_format=data_format)
    for index in range(3):
        spatial = conv3d(spatial,
                         filters=base_filters,
                         kernel_size=[7, 7, 7],
                         name='sconv{}'.format(index),
                         activation=leaky_relu,
                         data_format=data_format,
                         is_training=is_training,
                         padding=pad_mode)
    spatial = conv3d(spatial,
                     filters=num_heatmaps,
                     kernel_size=[7, 7, 7],
                     name='spatial_downsampled',
                     kernel_initializer=heatmap_init,
                     activation=tf.nn.tanh,
                     data_format=data_format,
                     is_training=is_training,
                     padding=pad_mode)
    spatial_heatmaps = upsample3d_cubic(spatial, [factor] * 3,
                                        name='spatial_heatmaps',
                                        data_format=data_format,
                                        padding='valid_cropped')

    heatmaps = local_heatmaps * spatial_heatmaps
    return heatmaps, local_heatmaps, spatial_heatmaps
def network_scn(input, num_labels, is_training, data_format='channels_first'):
    """
    Spatial configuration net for segmentation/regression.

    Local appearance: UnetClassicAvgLinear3d followed by a tanh 1x1x1 convolution.
    Spatial configuration: the local prediction is average-pooled by 4, passed through
    three 5x5x5 ReLU convolutions and a linear 5x5x5 convolution, then linearly upsampled by 4.
    The final prediction is the element-wise product of both stages.
    :param input: Input tensor.
    :param num_labels: Number of outputs.
    :param is_training: True, if training network.
    :param data_format: 'channels_first' or 'channels_last'. Forwarded to every layer.
    :return: (prediction, local_prediction, spatial_prediction)
    """
    downsampling_factor = 4
    kernel_initializer = he_initializer
    activation = tf.nn.relu
    local_kernel_initializer = he_initializer
    local_activation = tf.nn.tanh
    spatial_kernel_initializer = he_initializer
    spatial_activation = None
    padding = 'reflect'
    # NOTE: data_format is now forwarded to every conv/pool/upsample layer; previously only the
    # U-Net received it, so 'channels_last' input was processed with the layers' default format.
    with tf.variable_scope('unet'):
        unet = UnetClassicAvgLinear3d(64,
                                      4,
                                      data_format=data_format,
                                      double_filters_per_level=True,
                                      kernel_initializer=kernel_initializer,
                                      activation=activation,
                                      padding=padding)
        local_prediction = unet(input, is_training=is_training)
        local_prediction = conv3d(local_prediction,
                                  num_labels, [1, 1, 1],
                                  name='local_prediction',
                                  padding=padding,
                                  kernel_initializer=local_kernel_initializer,
                                  activation=local_activation,
                                  is_training=is_training,
                                  data_format=data_format)
    with tf.variable_scope('spatial_configuration'):
        local_prediction_pool = avg_pool3d(local_prediction,
                                           [downsampling_factor] * 3,
                                           name='local_prediction_pool',
                                           data_format=data_format)
        # Three identical 5x5x5 convolutions named scconv0..scconv2.
        scconv = local_prediction_pool
        for i in range(3):
            scconv = conv3d(scconv,
                            64, [5, 5, 5],
                            name='scconv{}'.format(i),
                            padding=padding,
                            kernel_initializer=kernel_initializer,
                            activation=activation,
                            is_training=is_training,
                            data_format=data_format)
        spatial_prediction_pool = conv3d(
            scconv,
            num_labels, [5, 5, 5],
            name='spatial_prediction_pool',
            padding=padding,
            kernel_initializer=spatial_kernel_initializer,
            activation=spatial_activation,
            is_training=is_training,
            data_format=data_format)
        spatial_prediction = upsample3d_linear(spatial_prediction_pool,
                                               [downsampling_factor] * 3,
                                               name='spatial_prediction',
                                               padding='valid_cropped',
                                               data_format=data_format)
    with tf.variable_scope('combination'):
        prediction = local_prediction * spatial_prediction
    return prediction, local_prediction, spatial_prediction
Exemplo n.º 9
0
def spatial_configuration_net(input,
                              num_labels,
                              is_training,
                              data_format='channels_first',
                              actual_network=None,
                              padding=None,
                              spatial_downsample=8,
                              *args,
                              **kwargs):
    """
    The spatial configuration net.
    :param input: Input tensor.
    :param num_labels: Number of outputs.
    :param is_training: True, if training network.
    :param data_format: 'channels_first' or 'channels_last'
    :param actual_network: The u-net class/factory used as the local appearance network. Required; it is
                           instantiated with num_filters_base, num_levels, double_filters_per_level,
                           normalization, activation, data_format and padding.
    :param padding: Padding parameter passed to the convolution operations.
    :param spatial_downsample: Downsampling factor for the spatial configuration stage.
    :param args: Not used.
    :param kwargs: Not used.
    :return: heatmaps, local_heatmaps, spatial_heatmaps
    :raises ValueError: If actual_network is None.
    """
    # Validate up front; previously a missing network failed with "'NoneType' object is not callable"
    # only after the first convolution had already been built.
    if actual_network is None:
        raise ValueError('actual_network must be given, got None')
    num_filters_base = 64
    activation = lambda x, name: tf.nn.leaky_relu(x, name=name, alpha=0.1)
    heatmap_layer_kernel_initializer = tf.truncated_normal_initializer(
        stddev=0.001)
    downsampling_factor = spatial_downsample
    # --- local appearance stage ---
    node = conv3d(input,
                  filters=num_filters_base,
                  kernel_size=[3, 3, 3],
                  name='conv0',
                  activation=activation,
                  data_format=data_format,
                  is_training=is_training)
    scnet_local = actual_network(num_filters_base=num_filters_base,
                                 num_levels=4,
                                 double_filters_per_level=False,
                                 normalization=None,
                                 activation=activation,
                                 data_format=data_format,
                                 padding=padding)
    unet_out = scnet_local(node, is_training)
    local_heatmaps = conv3d(
        unet_out,
        filters=num_labels,
        kernel_size=[3, 3, 3],
        name='local_heatmaps',
        kernel_initializer=heatmap_layer_kernel_initializer,
        activation=None,
        data_format=data_format,
        is_training=is_training)
    # --- spatial configuration stage (low resolution) ---
    downsampled = avg_pool3d(local_heatmaps, [downsampling_factor] * 3,
                             name='local_downsampled',
                             data_format=data_format)
    conv = conv3d(downsampled,
                  filters=num_filters_base,
                  kernel_size=[7, 7, 7],
                  name='sconv0',
                  activation=activation,
                  data_format=data_format,
                  is_training=is_training,
                  padding=padding)
    conv = conv3d(conv,
                  filters=num_filters_base,
                  kernel_size=[7, 7, 7],
                  name='sconv1',
                  activation=activation,
                  data_format=data_format,
                  is_training=is_training,
                  padding=padding)
    conv = conv3d(conv,
                  filters=num_filters_base,
                  kernel_size=[7, 7, 7],
                  name='sconv2',
                  activation=activation,
                  data_format=data_format,
                  is_training=is_training,
                  padding=padding)
    conv = conv3d(conv,
                  filters=num_labels,
                  kernel_size=[7, 7, 7],
                  name='spatial_downsampled',
                  kernel_initializer=heatmap_layer_kernel_initializer,
                  activation=tf.nn.tanh,
                  data_format=data_format,
                  is_training=is_training,
                  padding=padding)
    if data_format == 'channels_last':
        # 'channels_last' is assumed to mean CPU execution, where resize_tricubic is
        # presumably faster than upsample3d_cubic -- TODO confirm.
        # NOTE(review): unlike the branch below, no padding='valid_cropped' is passed here;
        # confirm both paths produce aligned outputs.
        spatial_heatmaps = resize_tricubic(conv,
                                           factors=[downsampling_factor] * 3,
                                           name='spatial_heatmaps',
                                           data_format=data_format)
    else:
        # 'channels_first' is assumed to mean GPU execution, where upsample3d_cubic is
        # presumably the faster choice.
        spatial_heatmaps = upsample3d_cubic(conv,
                                            factors=[downsampling_factor] * 3,
                                            name='spatial_heatmaps',
                                            data_format=data_format,
                                            padding='valid_cropped')

    heatmaps = local_heatmaps * spatial_heatmaps

    return heatmaps, local_heatmaps, spatial_heatmaps