Beispiel #1
0
 def downsample(self, node, current_level, is_training):
     """
     Downsample ``node`` with a 2x2x2 average pooling layer.

     :param node: Input tensor.
     :param current_level: Level index; only used to build the layer name 'downsample<level>'.
     :param is_training: Not used here (presumably kept to match the base-class interface — confirm).
     :return: The tensor produced by avg_pool3d.
     """
     pool_size = [2] * 3
     layer_name = 'downsample' + str(current_level)
     return avg_pool3d(node,
                       pool_size,
                       name=layer_name,
                       data_format=self.data_format)
def network_scn(input, num_labels, is_training, data_format='channels_first'):
    """
    Spatial configuration net with a classic average-pool/linear U-Net as local stage.

    The local stage predicts ``num_labels`` responses (tanh-activated 1x1x1 conv on
    the U-Net output). The spatial configuration stage average-pools those
    responses by ``downsampling_factor``, applies 5x5x5 convolutions and upsamples
    the result back with linear interpolation. The final prediction is the
    element-wise product of the local and spatial predictions.

    :param input: Input tensor.
    :param num_labels: Number of output channels.
    :param is_training: True, if training network.
    :param data_format: 'channels_first' or 'channels_last'; forwarded to every layer.
    :return: prediction, local_prediction, spatial_prediction
    """
    downsampling_factor = 4
    num_spatial_filters = 64
    num_spatial_layers = 3
    kernel_initializer = he_initializer
    activation = tf.nn.relu
    local_kernel_initializer = he_initializer
    local_activation = tf.nn.tanh
    spatial_kernel_initializer = he_initializer
    spatial_activation = None
    padding = 'reflect'
    # data_format is passed to every layer (not only the U-Net) so that a
    # 'channels_last' network does not mix tensor layouts between layers.
    with tf.variable_scope('unet'):
        unet = UnetClassicAvgLinear3d(64,
                                      4,
                                      data_format=data_format,
                                      double_filters_per_level=True,
                                      kernel_initializer=kernel_initializer,
                                      activation=activation,
                                      padding=padding)
        local_prediction = unet(input, is_training=is_training)
        local_prediction = conv3d(local_prediction,
                                  num_labels, [1, 1, 1],
                                  name='local_prediction',
                                  padding=padding,
                                  kernel_initializer=local_kernel_initializer,
                                  activation=local_activation,
                                  data_format=data_format,
                                  is_training=is_training)
    with tf.variable_scope('spatial_configuration'):
        local_prediction_pool = avg_pool3d(local_prediction,
                                           [downsampling_factor] * 3,
                                           name='local_prediction_pool',
                                           data_format=data_format)
        # Layer names scconv0, scconv1, ... are kept identical to the previous
        # unrolled version so existing checkpoints still match.
        scconv = local_prediction_pool
        for i in range(num_spatial_layers):
            scconv = conv3d(scconv,
                            num_spatial_filters, [5, 5, 5],
                            name='scconv' + str(i),
                            padding=padding,
                            kernel_initializer=kernel_initializer,
                            activation=activation,
                            data_format=data_format,
                            is_training=is_training)
        spatial_prediction_pool = conv3d(
            scconv,
            num_labels, [5, 5, 5],
            name='spatial_prediction_pool',
            padding=padding,
            kernel_initializer=spatial_kernel_initializer,
            activation=spatial_activation,
            data_format=data_format,
            is_training=is_training)
        spatial_prediction = upsample3d_linear(spatial_prediction_pool,
                                               [downsampling_factor] * 3,
                                               name='spatial_prediction',
                                               data_format=data_format,
                                               padding='valid_cropped')
    with tf.variable_scope('combination'):
        prediction = local_prediction * spatial_prediction
    return prediction, local_prediction, spatial_prediction
Beispiel #3
0
def network_scn(input,
                num_heatmaps,
                is_training,
                data_format='channels_first'):
    """
    Spatial configuration net using SCNetLocal as the local appearance stage.

    Local heatmaps come from SCNetLocal followed by a 3x3x3 conv. The spatial
    stage average-pools them by a factor of 8, runs four 7x7x7 convolutions
    (tanh on the last one) and upsamples back with cubic interpolation.

    :param input: Input tensor.
    :param num_heatmaps: Number of output heatmaps.
    :param is_training: True, if training network.
    :param data_format: 'channels_first' or 'channels_last'.
    :return: heatmaps (local * spatial), local_heatmaps, spatial_heatmaps
    """
    filters = 64
    leaky = lambda x, name: tf.nn.leaky_relu(x, name=name, alpha=0.1)
    pad_mode = 'reflect'
    heatmap_init = tf.truncated_normal_initializer(stddev=0.001)
    factor = 8

    # Local appearance stage.
    first = conv3d(input,
                   filters=filters,
                   kernel_size=[3] * 3,
                   name='conv0',
                   activation=leaky,
                   data_format=data_format,
                   is_training=is_training)
    local_net = SCNetLocal(num_filters_base=filters,
                           num_levels=4,
                           double_filters_per_level=False,
                           normalization=None,
                           activation=leaky,
                           data_format=data_format,
                           padding=pad_mode)
    local_features = local_net(first, is_training)
    local_heatmaps = conv3d(local_features,
                            filters=num_heatmaps,
                            kernel_size=[3] * 3,
                            name='local_heatmaps',
                            kernel_initializer=heatmap_init,
                            activation=None,
                            data_format=data_format,
                            is_training=is_training)

    # Spatial configuration stage on the downsampled local heatmaps.
    node = avg_pool3d(local_heatmaps, [factor] * 3,
                      name='local_downsampled',
                      data_format=data_format)
    for index in range(3):
        node = conv3d(node,
                      filters=filters,
                      kernel_size=[7] * 3,
                      name='sconv' + str(index),
                      activation=leaky,
                      data_format=data_format,
                      is_training=is_training,
                      padding=pad_mode)
    node = conv3d(node,
                  filters=num_heatmaps,
                  kernel_size=[7] * 3,
                  name='spatial_downsampled',
                  kernel_initializer=heatmap_init,
                  activation=tf.nn.tanh,
                  data_format=data_format,
                  is_training=is_training,
                  padding=pad_mode)
    spatial_heatmaps = upsample3d_cubic(node, [factor] * 3,
                                        name='spatial_heatmaps',
                                        data_format=data_format,
                                        padding='valid_cropped')

    return local_heatmaps * spatial_heatmaps, local_heatmaps, spatial_heatmaps
Beispiel #4
0
def spatial_configuration_net(input,
                              num_labels,
                              is_training,
                              data_format='channels_first',
                              actual_network=None,
                              padding=None,
                              spatial_downsample=8,
                              *args,
                              **kwargs):
    """
    The spatial configuration net.

    The local appearance network (built from ``actual_network``) followed by a
    3x3x3 conv yields ``num_labels`` local heatmaps. The spatial configuration
    stage average-pools them by ``spatial_downsample``, applies four 7x7x7
    convolutions (tanh on the last) and upsamples the result back with cubic
    interpolation. The returned heatmaps are local * spatial.

    :param input: Input tensor.
    :param num_labels: Number of outputs.
    :param is_training: True, if training network.
    :param data_format: 'channels_first' or 'channels_last'
    :param actual_network: The u-net class used as the local appearance network. It is
                           called with the keyword arguments num_filters_base, num_levels,
                           double_filters_per_level, normalization, activation, data_format
                           and padding; the result is called as network(node, is_training).
                           Required despite the None default.
    :param padding: Padding parameter passed to the local appearance network and to the
                    spatial configuration convolutions (not to conv0 / local_heatmaps).
    :param spatial_downsample: Downsampling factor for the spatial configuration stage.
    :param args: Not used.
    :param kwargs: Not used.
    :return: heatmaps, local_heatmaps, spatial_heatmaps
    :raises TypeError: If actual_network is None.
    """
    # Fail early with a clear message instead of "'NoneType' object is not
    # callable" after conv0 has already been added to the graph.
    if actual_network is None:
        raise TypeError('spatial_configuration_net requires actual_network (the local appearance network class), got None')

    num_filters_base = 64
    num_spatial_layers = 3
    activation = lambda x, name: tf.nn.leaky_relu(x, name=name, alpha=0.1)
    heatmap_layer_kernel_initializer = tf.truncated_normal_initializer(
        stddev=0.001)
    downsampling_factor = spatial_downsample
    node = conv3d(input,
                  filters=num_filters_base,
                  kernel_size=[3, 3, 3],
                  name='conv0',
                  activation=activation,
                  data_format=data_format,
                  is_training=is_training)
    scnet_local = actual_network(num_filters_base=num_filters_base,
                                 num_levels=4,
                                 double_filters_per_level=False,
                                 normalization=None,
                                 activation=activation,
                                 data_format=data_format,
                                 padding=padding)
    unet_out = scnet_local(node, is_training)
    local_heatmaps = conv3d(
        unet_out,
        filters=num_labels,
        kernel_size=[3, 3, 3],
        name='local_heatmaps',
        kernel_initializer=heatmap_layer_kernel_initializer,
        activation=None,
        data_format=data_format,
        is_training=is_training)
    downsampled = avg_pool3d(local_heatmaps, [downsampling_factor] * 3,
                             name='local_downsampled',
                             data_format=data_format)
    # Layer names sconv0, sconv1, ... match the previous unrolled version.
    conv = downsampled
    for i in range(num_spatial_layers):
        conv = conv3d(conv,
                      filters=num_filters_base,
                      kernel_size=[7, 7, 7],
                      name='sconv' + str(i),
                      activation=activation,
                      data_format=data_format,
                      is_training=is_training,
                      padding=padding)
    conv = conv3d(conv,
                  filters=num_labels,
                  kernel_size=[7, 7, 7],
                  name='spatial_downsampled',
                  kernel_initializer=heatmap_layer_kernel_initializer,
                  activation=tf.nn.tanh,
                  data_format=data_format,
                  is_training=is_training,
                  padding=padding)
    if data_format == 'channels_last':
        # 'channels_last' is assumed to mean CPU execution; resize_tricubic is
        # used there because it is presumably faster on CPU (NOTE(review): the
        # original comment named resize_trilinear — confirm the rationale).
        spatial_heatmaps = resize_tricubic(conv,
                                           factors=[downsampling_factor] * 3,
                                           name='spatial_heatmaps',
                                           data_format=data_format)
    else:
        # 'channels_first' is assumed to mean GPU execution; upsample3d_cubic is
        # used there because it is presumably faster on GPU (NOTE(review): the
        # original comment named upsample3d_linear — confirm the rationale).
        spatial_heatmaps = upsample3d_cubic(conv,
                                            factors=[downsampling_factor] * 3,
                                            name='spatial_heatmaps',
                                            data_format=data_format,
                                            padding='valid_cropped')

    heatmaps = local_heatmaps * spatial_heatmaps

    return heatmaps, local_heatmaps, spatial_heatmaps