def network_downsampling(input,
                         num_landmarks,
                         is_training,
                         data_format='channels_first'):
    """Plain multi-level downsampling network.

    Each level applies two 5x5 convolutions; average pooling halves the
    spatial resolution between levels (but not after the last one).  A final
    1x1 convolution produces one heatmap channel per landmark.

    Args:
        input: Input image tensor.
        num_landmarks: Number of output heatmap channels.
        is_training: Training-mode flag forwarded to the layer helpers.
        data_format: 'channels_first' or 'channels_last'.

    Returns:
        Heatmap tensor with num_landmarks channels.
    """
    num_filters = 128
    kernel_size = [5, 5]
    num_levels = 3
    padding = 'same'
    kernel_initializer = he_initializer
    activation = tf.nn.relu
    heatmap_initializer = tf.truncated_normal_initializer(stddev=0.0001)
    heatmap_activation = None
    node = input
    with tf.variable_scope('downsampling'):
        for level in range(num_levels):
            with tf.variable_scope('level' + str(level)):
                # Two identical convolutions per level.
                for conv_name in ('conv0', 'conv1'):
                    node = conv2d(node,
                                  num_filters,
                                  kernel_size=kernel_size,
                                  name=conv_name,
                                  activation=activation,
                                  kernel_initializer=kernel_initializer,
                                  padding=padding,
                                  data_format=data_format,
                                  is_training=is_training)
                # Pool between levels only; keep the last level's resolution.
                if level < num_levels - 1:
                    node = avg_pool2d(node, [2, 2],
                                      name='downsampling',
                                      data_format=data_format)
        # 1x1 conv head; small-stddev init keeps initial heatmaps near zero.
        heatmaps = conv2d(node,
                          num_landmarks,
                          kernel_size=[1, 1],
                          name='heatmaps',
                          activation=heatmap_activation,
                          kernel_initializer=heatmap_initializer,
                          padding=padding,
                          data_format=data_format,
                          is_training=is_training)
    return heatmaps
 def downsample(self, node, current_level, is_training):
     """Average-pool `node` by a factor of 2 along each spatial dimension.

     `current_level` only names the op; `is_training` is accepted for
     interface uniformity with sibling layer methods but is unused here.
     NOTE(review): indentation of this method looks mangled and its
     enclosing class is not visible in this chunk — verify placement.
     """
     return avg_pool2d(node, [2] * 2,
                       name='downsample' + str(current_level),
                       data_format=self.data_format)
def network_scn_mia(input,
                    num_landmarks,
                    is_training,
                    data_format='channels_first'):
    """SCN (MIA variant): local appearance heatmaps multiplied with a
    spatial-configuration component.

    Local heatmaps come from an SCNetLocal backbone; the spatial component
    is computed with 11x11 convolutions on a 16x downsampled grid and
    brought back to full resolution with cubic upsampling.

    Args:
        input: Input image tensor.
        num_landmarks: Number of output heatmap channels.
        is_training: Training-mode flag forwarded to the layer helpers.
        data_format: 'channels_first' or 'channels_last'.

    Returns:
        Heatmap tensor with num_landmarks channels.
    """
    num_filters_base = 128
    activation = lambda x, name: tf.nn.leaky_relu(x, name=name, alpha=0.1)
    padding = 'same'
    heatmap_layer_kernel_initializer = tf.truncated_normal_initializer(
        stddev=0.0001)
    downsampling_factor = 16
    dim = 2
    # Initial feature extraction before the backbone.
    features = conv2d(input,
                      filters=num_filters_base,
                      kernel_size=[3] * dim,
                      name='conv0',
                      activation=activation,
                      kernel_initializer=he_initializer,
                      data_format=data_format,
                      is_training=is_training)
    scnet_local = SCNetLocal(num_filters_base=num_filters_base,
                             num_levels=4,
                             double_filters_per_level=False,
                             normalization=None,
                             kernel_initializer=he_initializer,
                             activation=activation,
                             data_format=data_format,
                             padding=padding)
    backbone_out = scnet_local(features, is_training)
    # Local appearance heatmaps; small-stddev init keeps them near zero.
    local_heatmaps = conv2d(
        backbone_out,
        filters=num_landmarks,
        kernel_size=[3] * dim,
        name='local_heatmaps',
        kernel_initializer=heatmap_layer_kernel_initializer,
        activation=None,
        data_format=data_format,
        is_training=is_training)
    # Spatial-configuration component on the downsampled grid.
    spatial = avg_pool2d(local_heatmaps, [downsampling_factor] * dim,
                         name='local_downsampled',
                         data_format=data_format)
    for layer_name in ('sconv0', 'sconv1', 'sconv2'):
        spatial = conv2d(spatial,
                         filters=num_filters_base,
                         kernel_size=[11] * dim,
                         kernel_initializer=he_initializer,
                         name=layer_name,
                         activation=activation,
                         data_format=data_format,
                         is_training=is_training,
                         padding=padding)
    # tanh keeps the spatial responses bounded for the final product.
    spatial = conv2d(spatial,
                     filters=num_landmarks,
                     kernel_size=[11] * dim,
                     name='spatial_downsampled',
                     kernel_initializer=heatmap_layer_kernel_initializer,
                     activation=tf.nn.tanh,
                     data_format=data_format,
                     is_training=is_training,
                     padding=padding)
    spatial_heatmaps = upsample2d_cubic(spatial,
                                        factors=[downsampling_factor] * dim,
                                        name='spatial_heatmaps',
                                        data_format=data_format,
                                        padding='valid_cropped')
    # Multiplicative combination of both components.
    return local_heatmaps * spatial_heatmaps
def network_scn(input,
                num_landmarks,
                is_training,
                data_format='channels_first'):
    """Spatial configuration network (per-landmark variant).

    A stack of 5x5 convolutions yields per-landmark local appearance
    heatmaps.  For each landmark i, a spatial-configuration response is
    estimated from the downsampled heatmaps of all *other* landmarks with a
    15x15 convolution; the upsampled responses are multiplied with the local
    heatmaps.

    Args:
        input: Input image tensor.
        num_landmarks: Number of output heatmap channels.
        is_training: Training-mode flag forwarded to the layer helpers.
        data_format: 'channels_first' or 'channels_last'.

    Returns:
        Heatmap tensor with num_landmarks channels.
    """
    num_filters = 128
    local_kernel_size = [5, 5]
    spatial_kernel_size = [15, 15]
    downsampling_factor = 8
    padding = 'same'
    kernel_initializer = he_initializer
    activation = tf.nn.relu
    heatmap_initializer = tf.truncated_normal_initializer(stddev=0.0001)
    local_activation = None
    spatial_activation = None
    with tf.variable_scope('local_appearance'):
        node = input
        # Three 5x5 feature-extraction layers.
        for conv_name in ('conv1', 'conv2', 'conv3'):
            node = conv2d(node,
                          num_filters,
                          kernel_size=local_kernel_size,
                          name=conv_name,
                          activation=activation,
                          kernel_initializer=kernel_initializer,
                          padding=padding,
                          data_format=data_format,
                          is_training=is_training)
        local_heatmaps = conv2d(node,
                                num_landmarks,
                                kernel_size=local_kernel_size,
                                name='local_heatmaps',
                                activation=local_activation,
                                kernel_initializer=heatmap_initializer,
                                padding=padding,
                                data_format=data_format,
                                is_training=is_training)
    with tf.variable_scope('spatial_configuration'):
        local_heatmaps_downsampled = avg_pool2d(
            local_heatmaps, [downsampling_factor, downsampling_factor],
            name='local_heatmaps_downsampled',
            data_format=data_format)
        channel_axis = get_channel_index(local_heatmaps_downsampled,
                                         data_format)
        # One downsampled heatmap tensor per landmark.
        per_landmark = tf.split(local_heatmaps_downsampled,
                                num_landmarks,
                                channel_axis)
        spatial_responses = []
        for i in range(num_landmarks):
            # Every landmark's heatmap except the current one.
            others = per_landmark[:i] + per_landmark[i + 1:]
            local_heatmaps_except_i = tf.concat(others,
                                                name='h_app_except_' + str(i),
                                                axis=channel_axis)
            h_acc = conv2d(local_heatmaps_except_i,
                           1,
                           kernel_size=spatial_kernel_size,
                           name='h_acc_' + str(i),
                           activation=spatial_activation,
                           kernel_initializer=heatmap_initializer,
                           padding=padding,
                           data_format=data_format,
                           is_training=is_training)
            spatial_responses.append(h_acc)
        spatial_heatmaps_downsampled = tf.concat(
            spatial_responses,
            name='spatial_heatmaps_downsampled',
            axis=channel_axis)
        spatial_heatmaps = upsample2d_linear(
            spatial_heatmaps_downsampled,
            [downsampling_factor, downsampling_factor],
            name='spatial_prediction',
            padding='valid_cropped',
            data_format=data_format)
    with tf.variable_scope('combination'):
        # Multiplicative combination of both components.
        heatmaps = local_heatmaps * spatial_heatmaps
    return heatmaps
def network_scn_mmwhs(input,
                      num_landmarks,
                      is_training,
                      data_format='channels_first'):
    """SCN variant for the MM-WHS dataset.

    A classic U-Net produces tanh-activated local landmark predictions,
    which are multiplied with a spatial-configuration component computed
    with 5x5 convolutions on an 8x downsampled grid.

    Args:
        input: Input image tensor.
        num_landmarks: Number of output heatmap channels.
        is_training: Training-mode flag forwarded to the layer helpers.
        data_format: 'channels_first' or 'channels_last'.

    Returns:
        Heatmap tensor with num_landmarks channels.
    """
    downsampling_factor = 8
    num_filters = 128
    num_levels = 4
    spatial_kernel_size = [5, 5]
    kernel_initializer = he_initializer
    activation = tf.nn.relu
    local_kernel_initializer = tf.truncated_normal_initializer(stddev=0.0001)
    local_activation = tf.nn.tanh
    spatial_kernel_initializer = tf.truncated_normal_initializer(stddev=0.0001)
    spatial_activation = None
    padding = 'reflect'
    with tf.variable_scope('unet'):
        unet = UnetClassicAvgLinear2D(num_filters,
                                      num_levels,
                                      data_format=data_format,
                                      double_filters_per_level=False,
                                      kernel_initializer=kernel_initializer,
                                      activation=activation,
                                      padding=padding)
        local_prediction = unet(input, is_training=is_training)
        local_prediction = conv2d(local_prediction,
                                  num_landmarks, [1, 1],
                                  name='local_prediction',
                                  padding=padding,
                                  kernel_initializer=local_kernel_initializer,
                                  activation=local_activation,
                                  data_format=data_format,
                                  is_training=is_training)
    with tf.variable_scope('spatial_configuration'):
        # FIX: forward data_format to every pool/conv/upsample call, as the
        # sibling networks in this file do.  Previously these calls relied on
        # the helpers' default and would operate on the wrong axes whenever
        # the caller's data_format differed from that default.
        local_prediction_pool = avg_pool2d(local_prediction,
                                           [downsampling_factor] * 2,
                                           name='local_prediction_pool',
                                           data_format=data_format)
        scconv = local_prediction_pool
        # Three 5x5 convolutions on the downsampled grid.
        for scconv_name in ('scconv0', 'scconv1', 'scconv2'):
            scconv = conv2d(scconv,
                            num_filters,
                            spatial_kernel_size,
                            name=scconv_name,
                            padding=padding,
                            kernel_initializer=kernel_initializer,
                            activation=activation,
                            data_format=data_format,
                            is_training=is_training)
        spatial_prediction_pool = conv2d(
            scconv,
            num_landmarks,
            spatial_kernel_size,
            name='spatial_prediction_pool',
            padding=padding,
            kernel_initializer=spatial_kernel_initializer,
            activation=spatial_activation,
            data_format=data_format,
            is_training=is_training)
        spatial_prediction = upsample2d_linear(spatial_prediction_pool,
                                               [downsampling_factor] * 2,
                                               name='spatial_prediction',
                                               padding='valid_cropped',
                                               data_format=data_format)
    with tf.variable_scope('combination'):
        # Multiplicative combination of both components.
        prediction = local_prediction * spatial_prediction
    return prediction