def network_downsampling(input, num_landmarks, is_training, data_format='channels_first'):
    """Plain downsampling CNN that predicts one heatmap per landmark.

    Each of the ``num_levels`` (3) levels applies two 5x5 convolutions with
    128 filters and ReLU. Every level except the last ends with a 2x2 average
    pooling. A final 1x1 convolution without activation maps the features to
    ``num_landmarks`` channels.

    Args:
        input: input image tensor; layout is given by ``data_format``.
        num_landmarks: number of output heatmap channels.
        is_training: forwarded unchanged to every ``conv2d`` call.
        data_format: tensor layout forwarded to every layer.

    Returns:
        The heatmap tensor produced by the final ``heatmaps`` convolution.
    """
    num_levels = 3
    num_filters = 128
    kernel_size = [5, 5]
    padding = 'same'
    kernel_initializer = he_initializer
    activation = tf.nn.relu
    # Very small init keeps the initial heatmap predictions close to zero.
    heatmap_initializer = tf.truncated_normal_initializer(stddev=0.0001)
    heatmap_activation = None

    node = input
    with tf.variable_scope('downsampling'):
        for level in range(num_levels):
            with tf.variable_scope('level{}'.format(level)):
                for layer_name in ('conv0', 'conv1'):
                    node = conv2d(node,
                                  num_filters,
                                  kernel_size=kernel_size,
                                  name=layer_name,
                                  activation=activation,
                                  kernel_initializer=kernel_initializer,
                                  padding=padding,
                                  data_format=data_format,
                                  is_training=is_training)
                is_last_level = level == num_levels - 1
                if not is_last_level:
                    # Halve the spatial resolution between consecutive levels.
                    node = avg_pool2d(node, [2, 2], name='downsampling', data_format=data_format)
        heatmaps = conv2d(node,
                          num_landmarks,
                          kernel_size=[1, 1],
                          name='heatmaps',
                          activation=heatmap_activation,
                          kernel_initializer=heatmap_initializer,
                          padding=padding,
                          data_format=data_format,
                          is_training=is_training)
    return heatmaps
def downsample(self, node, current_level, is_training):
    """Average-pool ``node`` by a factor of 2 along both spatial axes.

    Args:
        node: tensor to pool; layout is taken from ``self.data_format``.
        current_level: used only to build the op name (``downsample<level>``).
        is_training: accepted for interface compatibility; not used here.

    Returns:
        The pooled tensor.
    """
    pool_size = [2, 2]
    op_name = 'downsample{}'.format(current_level)
    return avg_pool2d(node, pool_size, name=op_name, data_format=self.data_format)
def network_scn_mia(input, num_landmarks, is_training, data_format='channels_first'):
    """SpatialConfiguration-Net variant with an ``SCNetLocal`` appearance branch.

    Local branch: a 3x3 conv (128 filters, leaky ReLU), then ``SCNetLocal``
    (4 levels), then a 3x3 conv with no activation producing
    ``local_heatmaps``.

    Spatial branch: ``local_heatmaps`` are average-pooled by 16. They pass
    through three 11x11 convs (128 filters, leaky ReLU) and one 11x11 conv to
    ``num_landmarks`` channels with ``tanh``. The result is upsampled by 16
    with ``upsample2d_cubic``.

    Args:
        input: input image tensor; layout is given by ``data_format``.
        num_landmarks: number of output heatmap channels.
        is_training: forwarded to every layer that accepts it.
        data_format: tensor layout forwarded to every layer.

    Returns:
        Element-wise product of the local and spatial heatmaps.
    """
    dim = 2
    num_filters_base = 128
    downsampling_factor = 16
    padding = 'same'
    # Very small init keeps the initial heatmap predictions close to zero.
    heatmap_layer_kernel_initializer = tf.truncated_normal_initializer(stddev=0.0001)

    def activation(x, name):
        return tf.nn.leaky_relu(x, name=name, alpha=0.1)

    # --- local appearance ---
    node = conv2d(input,
                  filters=num_filters_base,
                  kernel_size=[3] * dim,
                  name='conv0',
                  activation=activation,
                  kernel_initializer=he_initializer,
                  data_format=data_format,
                  is_training=is_training)
    scnet_local = SCNetLocal(num_filters_base=num_filters_base,
                             num_levels=4,
                             double_filters_per_level=False,
                             normalization=None,
                             kernel_initializer=he_initializer,
                             activation=activation,
                             data_format=data_format,
                             padding=padding)
    unet_out = scnet_local(node, is_training)
    local_heatmaps = conv2d(unet_out,
                            filters=num_landmarks,
                            kernel_size=[3] * dim,
                            name='local_heatmaps',
                            kernel_initializer=heatmap_layer_kernel_initializer,
                            activation=None,
                            data_format=data_format,
                            is_training=is_training)

    # --- spatial configuration ---
    spatial = avg_pool2d(local_heatmaps,
                         [downsampling_factor] * dim,
                         name='local_downsampled',
                         data_format=data_format)
    for layer_index in range(3):
        spatial = conv2d(spatial,
                         filters=num_filters_base,
                         kernel_size=[11] * dim,
                         kernel_initializer=he_initializer,
                         name='sconv{}'.format(layer_index),
                         activation=activation,
                         data_format=data_format,
                         is_training=is_training,
                         padding=padding)
    spatial = conv2d(spatial,
                     filters=num_landmarks,
                     kernel_size=[11] * dim,
                     name='spatial_downsampled',
                     kernel_initializer=heatmap_layer_kernel_initializer,
                     activation=tf.nn.tanh,
                     data_format=data_format,
                     is_training=is_training,
                     padding=padding)
    spatial_heatmaps = upsample2d_cubic(spatial,
                                        factors=[downsampling_factor] * dim,
                                        name='spatial_heatmaps',
                                        data_format=data_format,
                                        padding='valid_cropped')

    # --- combination ---
    return local_heatmaps * spatial_heatmaps
def network_scn(input, num_landmarks, is_training, data_format='channels_first'):
    """Original SpatialConfiguration-Net with per-landmark spatial filters.

    Local appearance: three 5x5 convs (128 filters, ReLU), then a 5x5 conv to
    ``num_landmarks`` channels without activation.

    Spatial configuration: the local heatmaps are average-pooled by 8 and
    split per channel. For each landmark ``i``, the remaining landmarks'
    channels are concatenated and passed through a 15x15 conv to one channel
    (``h_acc_<i>``). The outputs are concatenated and linearly upsampled by 8.

    Args:
        input: input image tensor; layout is given by ``data_format``.
        num_landmarks: number of output heatmap channels. The per-landmark
            concat excludes channel ``i``, so this presumably needs to be
            >= 2 (TODO confirm).
        is_training: forwarded to every ``conv2d`` call.
        data_format: tensor layout forwarded to every layer.

    Returns:
        Element-wise product of the local and spatial heatmaps.
    """
    num_filters = 128
    local_kernel_size = [5, 5]
    spatial_kernel_size = [15, 15]
    downsampling_factor = 8
    padding = 'same'
    kernel_initializer = he_initializer
    activation = tf.nn.relu
    # Very small init keeps the initial heatmap predictions close to zero.
    heatmap_initializer = tf.truncated_normal_initializer(stddev=0.0001)
    local_activation = None
    spatial_activation = None

    with tf.variable_scope('local_appearance'):
        features = input
        for layer_name in ('conv1', 'conv2', 'conv3'):
            features = conv2d(features,
                              num_filters,
                              kernel_size=local_kernel_size,
                              name=layer_name,
                              activation=activation,
                              kernel_initializer=kernel_initializer,
                              padding=padding,
                              data_format=data_format,
                              is_training=is_training)
        local_heatmaps = conv2d(features,
                                num_landmarks,
                                kernel_size=local_kernel_size,
                                name='local_heatmaps',
                                activation=local_activation,
                                kernel_initializer=heatmap_initializer,
                                padding=padding,
                                data_format=data_format,
                                is_training=is_training)

    with tf.variable_scope('spatial_configuration'):
        pooled = avg_pool2d(local_heatmaps,
                            [downsampling_factor, downsampling_factor],
                            name='local_heatmaps_downsampled',
                            data_format=data_format)
        channel_axis = get_channel_index(pooled, data_format)
        per_landmark = tf.split(pooled, num_landmarks, channel_axis)

        spatial_parts = []
        for landmark in range(num_landmarks):
            # Landmark i is predicted only from the *other* landmarks' heatmaps.
            others = [part for index, part in enumerate(per_landmark) if index != landmark]
            others_concat = tf.concat(others,
                                      name='h_app_except_{}'.format(landmark),
                                      axis=channel_axis)
            spatial_parts.append(conv2d(others_concat,
                                        1,
                                        kernel_size=spatial_kernel_size,
                                        name='h_acc_{}'.format(landmark),
                                        activation=spatial_activation,
                                        kernel_initializer=heatmap_initializer,
                                        padding=padding,
                                        data_format=data_format,
                                        is_training=is_training))

        spatial_pooled = tf.concat(spatial_parts,
                                   name='spatial_heatmaps_downsampled',
                                   axis=channel_axis)
        spatial_heatmaps = upsample2d_linear(spatial_pooled,
                                             [downsampling_factor, downsampling_factor],
                                             name='spatial_prediction',
                                             padding='valid_cropped',
                                             data_format=data_format)

    with tf.variable_scope('combination'):
        heatmaps = local_heatmaps * spatial_heatmaps
    return heatmaps
def network_scn_mmwhs(input, num_landmarks, is_training, data_format='channels_first'):
    """SpatialConfiguration-Net with a U-Net local-appearance branch.

    Local branch: ``UnetClassicAvgLinear2D`` (4 levels, 128 filters, ReLU,
    reflect padding), then a 1x1 conv to ``num_landmarks`` channels with
    ``tanh``.

    Spatial branch: the local prediction is average-pooled by 8. It passes
    through three 5x5 convs (128 filters, ReLU) and one 5x5 conv to
    ``num_landmarks`` channels without activation. The result is linearly
    upsampled by 8.

    Args:
        input: input image tensor; layout is given by ``data_format``.
        num_landmarks: number of output channels.
        is_training: forwarded to the U-Net and every ``conv2d`` call.
        data_format: tensor layout forwarded to *every* layer.

    Returns:
        Element-wise product of the local and spatial predictions.
    """
    downsampling_factor = 8
    num_filters = 128
    num_levels = 4
    spatial_kernel_size = [5, 5]
    kernel_initializer = he_initializer
    activation = tf.nn.relu
    # Very small inits keep both branches' initial predictions close to zero.
    local_kernel_initializer = tf.truncated_normal_initializer(stddev=0.0001)
    local_activation = tf.nn.tanh
    spatial_kernel_initializer = tf.truncated_normal_initializer(stddev=0.0001)
    spatial_activation = None
    padding = 'reflect'

    # data_format is passed to every layer below, not just the U-Net. Otherwise
    # the layers fall back to their own default layout and can disagree with
    # the U-Net's output whenever data_format differs from that default.
    with tf.variable_scope('unet'):
        unet = UnetClassicAvgLinear2D(num_filters,
                                      num_levels,
                                      data_format=data_format,
                                      double_filters_per_level=False,
                                      kernel_initializer=kernel_initializer,
                                      activation=activation,
                                      padding=padding)
        local_prediction = unet(input, is_training=is_training)
        local_prediction = conv2d(local_prediction,
                                  num_landmarks,
                                  [1, 1],
                                  name='local_prediction',
                                  padding=padding,
                                  kernel_initializer=local_kernel_initializer,
                                  activation=local_activation,
                                  data_format=data_format,
                                  is_training=is_training)

    with tf.variable_scope('spatial_configuration'):
        local_prediction_pool = avg_pool2d(local_prediction,
                                           [downsampling_factor] * 2,
                                           name='local_prediction_pool',
                                           data_format=data_format)
        scconv = local_prediction_pool
        for layer_index in range(3):
            scconv = conv2d(scconv,
                            num_filters,
                            spatial_kernel_size,
                            name='scconv{}'.format(layer_index),
                            padding=padding,
                            kernel_initializer=kernel_initializer,
                            activation=activation,
                            data_format=data_format,
                            is_training=is_training)
        spatial_prediction_pool = conv2d(scconv,
                                         num_landmarks,
                                         spatial_kernel_size,
                                         name='spatial_prediction_pool',
                                         padding=padding,
                                         kernel_initializer=spatial_kernel_initializer,
                                         activation=spatial_activation,
                                         data_format=data_format,
                                         is_training=is_training)
        spatial_prediction = upsample2d_linear(spatial_prediction_pool,
                                               [downsampling_factor] * 2,
                                               name='spatial_prediction',
                                               padding='valid_cropped',
                                               data_format=data_format)

    with tf.variable_scope('combination'):
        prediction = local_prediction * spatial_prediction
    return prediction