import numpy as np
import tensorflow as tf

# Helper functions such as get_channel_index, get_image_axes, reduce_mean_weighted,
# get_filler_kernel and the layer wrappers (conv2d, avg_pool2d, upsample2d_linear, ...)
# are assumed to be defined elsewhere in this repository.


def generalized_dice_loss(labels,
                          logits=None,
                          logits_as_probability=None,
                          data_format='channels_first',
                          weights=None,
                          weight_labels=True,
                          squared=True,
                          weight_epsilon=1e-08,
                          epsilon=1e-08):
    """
    Generalized dice loss, taken from "Generalised Dice overlap as a deep learning
    loss function for highly unbalanced segmentations" (https://arxiv.org/pdf/1707.03237.pdf).
    :param labels: groundtruth labels
    :param logits: network predictions; set either logits or logits_as_probability, but not both
    :param logits_as_probability: network predictions that are already softmax probabilities
    :param data_format: either 'channels_first' or 'channels_last'
    :param weights: if set, pixelwise weights used for a weighted mean over the batch
    :param weight_labels: if True, weight each label with the inverse of its squared volume (w_l)
    :param squared: if True, square the probabilities in the denominator
    :param weight_epsilon: used to avoid division by zero when calculating the label weights
    :param epsilon: used for numerical instabilities caused by denominator = 0
    :return: Tensor of mean generalized dice loss of all images of the batch
    """
    assert (logits is None and logits_as_probability is not None) or \
           (logits is not None and logits_as_probability is None), \
        'Set either logits or logits_as_probability, but not both.'
    channel_index = get_channel_index(labels, data_format)
    image_axes = get_image_axes(labels, data_format)
    labels_shape = labels.get_shape().as_list()
    num_labels = labels_shape[channel_index]
    # calculate logits probability as softmax (p_n)
    if logits_as_probability is None:
        logits_as_probability = tf.nn.softmax(logits, dim=channel_index)
    if weight_labels:
        # calculate label weights (w_l)
        label_weights = 1 / (tf.reduce_sum(labels, axis=image_axes) ** 2 + weight_epsilon)
    else:
        label_weights = 1
    # GDL_b based on equation in reference paper
    numerator = tf.reduce_sum(
        label_weights * tf.reduce_sum(labels * logits_as_probability, axis=image_axes),
        axis=1)
    if squared:
        # square the probabilities; no need to square the labels, as they are either 0 or 1
        denominator = tf.reduce_sum(
            label_weights * tf.reduce_sum(labels + (logits_as_probability ** 2), axis=image_axes),
            axis=1)
    else:
        denominator = tf.reduce_sum(
            label_weights * tf.reduce_sum(labels + logits_as_probability, axis=image_axes),
            axis=1)
    loss = 1 - 2 * (numerator + epsilon) / (denominator + epsilon)
    if weights is not None:
        channel_index = get_channel_index(weights, data_format)
        weights = tf.squeeze(weights, axis=channel_index)
        return reduce_mean_weighted(loss, weights)
    else:
        return tf.reduce_mean(loss)

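# Minimal usage sketch (an assumption, not part of the original module): computes
# the generalized dice loss for a batch of one-hot groundtruth labels and raw
# network outputs in 'channels_first' layout, i.e. [batch, labels, height, width].
labels_example = tf.placeholder(tf.float32, shape=[4, 6, 64, 64])
logits_example = tf.placeholder(tf.float32, shape=[4, 6, 64, 64])
gdl_loss = generalized_dice_loss(labels_example, logits=logits_example, data_format='channels_first')
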
def softmax_cross_entropy_with_logits(labels, logits, weights=None, data_format='channels_first'):
    channel_index = get_channel_index(labels, data_format)
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits, dim=channel_index)
    if weights is not None:
        channel_index = get_channel_index(weights, data_format)
        weights = tf.squeeze(weights, axis=channel_index)
        return reduce_mean_weighted(loss, weights)
    else:
        return tf.reduce_mean(loss)

def weighted_spread_loss(labels,
                         logits,
                         m_low=0.2,
                         m_high=0.9,
                         iteration_low_to_high=100000,
                         global_step=100000,
                         data_format='channels_first'):
    # hardcoded per-label weights (w_l), scaled such that they sum to the number of labels
    w_l = np.array([0.00705479, 0.03312549, 0.02664785, 0.4437354, 0.44254721, 0.04688926]) * 6
    # linearly anneal the margin m from m_low to m_high over iteration_low_to_high steps
    m = m_low + (m_high - m_low) * tf.minimum(
        tf.to_float(global_step / iteration_low_to_high), tf.to_float(1))
    n_labels = labels.get_shape().as_list()[1]
    # move the label axis to the front and flatten all remaining axes
    # (assumes 4D 'channels_first' tensors)
    labels = tf.transpose(labels, [1, 0, 2, 3])
    logits = tf.transpose(logits, [1, 0, 2, 3])
    labels = tf.reshape(labels, [n_labels, -1])
    logits = tf.reshape(logits, [n_labels, -1])
    true_class_logits = tf.reduce_max(labels * logits, axis=0)
    # squared hinge loss on the margin; tf.abs(labels - 1) masks out the true class
    margin_loss_pixel_class = tf.square(
        tf.nn.relu((m - true_class_logits + logits) * tf.abs(labels - 1)))
    loss = []
    for i in range(len(w_l)):
        # after the transpose/reshape above, the label axis is axis 0
        loss.append(w_l[i] * margin_loss_pixel_class * tf.gather(labels, [i], axis=0))
    loss = tf.reduce_sum(tf.stack(loss, axis=0), axis=0)
    loss = tf.reduce_mean(tf.reduce_sum(loss, axis=0))
    return loss

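# Minimal usage sketch (an assumption): in training code, global_step would be the
# training step tensor so that the margin m anneals from m_low to m_high.
step = tf.train.get_or_create_global_step()
spread_labels = tf.placeholder(tf.float32, shape=[2, 6, 64, 64])  # one-hot, 6 labels to match w_l
spread_logits = tf.placeholder(tf.float32, shape=[2, 6, 64, 64])
spread = weighted_spread_loss(spread_labels, spread_logits, global_step=step)
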
def layer_norm(inputs, is_training, name='', data_format='channels_first'):
    with tf.variable_scope(name):
        inputs_shape = inputs.get_shape().as_list()
        channel_index = get_channel_index(inputs, data_format)
        params_shape = [1] * len(inputs_shape)
        params_shape[channel_index] = inputs_shape[channel_index]
        # Allocate parameters for the beta and gamma of the normalization.
        beta = tf.get_variable('beta',
                               shape=params_shape,
                               dtype=tf.float32,
                               initializer=tf.zeros_initializer(),
                               trainable=is_training)
        gamma = tf.get_variable('gamma',
                                shape=params_shape,
                                dtype=tf.float32,
                                initializer=tf.ones_initializer(),
                                trainable=is_training)
        # normalize over all axes except the batch axis
        norm_axes = list(range(1, len(inputs_shape)))
        mean, variance = tf.nn.moments(inputs, norm_axes, keep_dims=True)
        # Compute layer normalization using the batch_normalization function.
        outputs = tf.nn.batch_normalization(inputs,
                                            mean,
                                            variance,
                                            offset=beta,
                                            scale=gamma,
                                            variance_epsilon=1e-12)
        return outputs

def instance_norm(inputs,
                  is_training,
                  name='',
                  data_format='channels_first',
                  epsilon=1e-5,
                  beta_initializer=tf.constant_initializer(0.0),
                  gamma_initializer=tf.constant_initializer(1.0)):
    with tf.variable_scope(name):
        channel_index = get_channel_index(inputs, data_format)
        image_axes = get_image_axes(inputs, data_format=data_format)
        depth = inputs.get_shape()[channel_index]
        mean, variance = tf.nn.moments(inputs, axes=image_axes, keep_dims=True)
        inv = tf.rsqrt(variance + epsilon)
        normalized = (inputs - mean) * inv
        offset = tf.get_variable('offset', [depth], trainable=is_training, initializer=beta_initializer)
        scale = tf.get_variable('scale', [depth], trainable=is_training, initializer=gamma_initializer)
        # reshape offset and scale so that they broadcast over the image axes
        offset_scale_shape = [1] * inputs.shape.ndims
        offset_scale_shape[channel_index] = depth
        offset = tf.reshape(offset, offset_scale_shape)
        scale = tf.reshape(scale, offset_scale_shape)
        return tf.identity(scale * normalized + offset, name='output')

def batch_norm(inputs, is_training, name='', data_format='channels_first'):
    # use the faster fused batch norm for 4D and 5D image tensors
    if inputs.shape.ndims == 4 or inputs.shape.ndims == 5:
        channel_index = get_channel_index(inputs, data_format)
        return tf.layers.batch_normalization(inputs,
                                             axis=channel_index,
                                             name=name + '/bn',
                                             training=is_training)
    else:
        raise Exception('This batch_norm only supports images. Use batch_norm_dense or the basic tensorflow version instead.')

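# Minimal usage sketch (an assumption): the normalization wrappers above share the
# same basic interface, so they can be exchanged via a function reference.
normalization = instance_norm  # or layer_norm / batch_norm
features = tf.placeholder(tf.float32, shape=[8, 64, 32, 32])
features_normalized = normalization(features, is_training=True, name='norm0', data_format='channels_first')
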
def concat_channels(inputs, name='', data_format='channels_first', debug_print=debug_print_others):
    axis = get_channel_index(inputs[0], data_format)
    outputs = tf.concat(inputs, axis=axis, name=name)
    if debug_print:
        print_shape_parameters(inputs, outputs, name, 'concat')
    return outputs

def pad_for_conv(inputs, kernel_size, name, padding, data_format):
    if padding in ['symmetric', 'reflect']:
        # TODO check if this works for even kernels
        channel_index = get_channel_index(inputs, data_format)
        paddings = np.array([[0, 0]] + [[int(ks / 2)] * 2 for ks in kernel_size])
        paddings = np.insert(paddings, channel_index, [0, 0], axis=0)
        outputs = tf.pad(inputs, paddings, mode=padding, name=name + '/pad')
        padding_for_conv = 'valid'
    elif padding == 'same_selu':
        # TODO check if this works for even kernels
        channel_index = get_channel_index(inputs, data_format)
        paddings = np.array([[0, 0]] + [[int(ks / 2)] * 2 for ks in kernel_size])
        paddings = np.insert(paddings, channel_index, [0, 0], axis=0)
        # constant_values = -lambda * alpha from the selu paper (lambda is named scale here)
        alpha = 1.6732632423543772848170429916717
        scale = 1.0507009873554804934193349852946
        pad_value = -alpha * scale
        outputs = tf.pad(inputs, paddings, mode='constant', constant_values=pad_value, name=name + '/pad')
        padding_for_conv = 'valid'
    else:
        outputs = inputs
        padding_for_conv = padding
    return outputs, padding_for_conv

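# Minimal usage sketch (an assumption): symmetric padding followed by a 'valid'
# convolution keeps the spatial size like 'same' padding, but mirrors the border
# content instead of padding with zeros.
conv_inputs = tf.placeholder(tf.float32, shape=[1, 3, 32, 32])
padded, conv_padding = pad_for_conv(conv_inputs, kernel_size=[3, 3], name='conv1',
                                    padding='symmetric', data_format='channels_first')
conv_outputs = tf.layers.conv2d(padded, filters=16, kernel_size=[3, 3],
                                padding=conv_padding, data_format='channels_first')
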
def weighted_softmax(labels, logits, data_format='channels_first'):
    # hardcoded per-label weights (w_l), scaled such that they sum to the number of labels
    w_l = np.array([0.00705479, 0.03312549, 0.02664785, 0.4437354, 0.44254721, 0.04688926]) * 6
    channel_index = get_channel_index(labels, data_format)
    loss_s = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits, dim=channel_index)
    loss = []
    for i in range(len(w_l)):
        # squeeze the gathered label channel so that it broadcasts with loss_s,
        # which has the channel axis already reduced
        label_i = tf.squeeze(tf.gather(labels, [i], axis=channel_index), axis=channel_index)
        loss.append(w_l[i] * loss_s * label_i)
    loss = tf.reduce_sum(tf.stack(loss, axis=0), axis=0)
    return tf.reduce_mean(loss)

def upsample_interpolation_function(inputs, factors, interpolation_function, support, name, data_format, padding):
    with tf.variable_scope(name):
        dim = len(factors)
        kernel = get_filler_kernel(interpolation_function, support, factors)
        kernel = kernel.reshape(kernel.shape + (1, 1))
        padding = padding.upper()
        # set padding and cropping parameters
        cropping = False
        if padding == 'VALID_CROPPED':
            cropping = True
            padding = 'VALID'
        # calculate convolution parameters
        input_size = get_image_size(inputs, data_format)
        data_format_tf = get_tf_data_format(inputs, data_format)
        inputs_shape = inputs.get_shape().as_list()
        channel_axis = get_channel_index(inputs, data_format)
        num_inputs = inputs_shape[channel_axis]
        # calculate output_size dependent on padding parameter
        if padding == 'SAME':
            output_size = [input_size[i] * factors[i] for i in range(dim)]
        else:
            output_size = [input_size[i] * factors[i] + kernel.shape[i] - factors[i] for i in range(dim)]
        output_shape = get_tensor_shape(batch_size=inputs_shape[0],
                                        channel_size=1,
                                        image_size=output_size,
                                        data_format=data_format)
        # strides are the scaling factors
        strides = get_tensor_shape(batch_size=1,
                                   channel_size=1,
                                   image_size=factors,
                                   data_format=data_format)
        # actual tensorflow operations - channelwise!
        split_inputs = tf.split(inputs, num_inputs, axis=channel_axis, name='split')
        output_list = []
        for i in range(len(split_inputs)):
            if dim == 2:
                current_output = tf.nn.conv2d_transpose(split_inputs[i],
                                                        kernel,
                                                        output_shape,
                                                        strides,
                                                        data_format=data_format_tf,
                                                        name='conv' + str(i),
                                                        padding=padding)
            else:  # dim == 3
                current_output = tf.nn.conv3d_transpose(split_inputs[i],
                                                        kernel,
                                                        output_shape,
                                                        strides,
                                                        data_format=data_format_tf,
                                                        name='conv' + str(i),
                                                        padding=padding)
            output_list.append(current_output)
        outputs = tf.concat(output_list, axis=channel_axis, name='concat')
        # make a final cropping, if specified
        if cropping:
            image_paddings = [int((kernel.shape[i] - factors[i]) / 2) for i in range(dim)]
            paddings = get_tensor_shape(batch_size=0,
                                        channel_size=0,
                                        image_size=image_paddings,
                                        data_format=data_format)
            output_size_cropped = [input_size[i] * factors[i] for i in range(dim)]
            outputs = tf.slice(outputs, paddings, [inputs_shape[0], inputs_shape[1]] + output_size_cropped)
        return tf.identity(outputs, name='output')

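# Minimal usage sketch (an assumption): get_filler_kernel and the interpolation
# function are defined elsewhere in this repository; linear_interpolation is a
# hypothetical placeholder for such a kernel function. The call upsamples a
# 'channels_first' tensor by a factor of 2 per spatial dimension.
upsample_inputs = tf.placeholder(tf.float32, shape=[1, 3, 32, 32])
upsampled = upsample_interpolation_function(upsample_inputs,
                                            factors=[2, 2],
                                            interpolation_function=linear_interpolation,
                                            support=2,
                                            name='upsample',
                                            data_format='channels_first',
                                            padding='valid_cropped')
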
def softmax(labels, logits, weights=None, data_format='channels_first'):
    channel_index = get_channel_index(labels, data_format)
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits, dim=channel_index)
    if weights is not None:
        channel_index = get_channel_index(weights, data_format)
        weights = tf.squeeze(weights, axis=channel_index)
        return reduce_mean_weighted(loss, weights)
    return tf.reduce_mean(loss)

def network_scn(input, num_landmarks, is_training, data_format='channels_first'):
    num_filters = 128
    local_kernel_size = [5, 5]
    spatial_kernel_size = [15, 15]
    downsampling_factor = 8
    padding = 'same'
    kernel_initializer = he_initializer
    activation = tf.nn.relu
    heatmap_initializer = tf.truncated_normal_initializer(stddev=0.0001)
    local_activation = None
    spatial_activation = None
    with tf.variable_scope('local_appearance'):
        node = conv2d(input,
                      num_filters,
                      kernel_size=local_kernel_size,
                      name='conv1',
                      activation=activation,
                      kernel_initializer=kernel_initializer,
                      padding=padding,
                      data_format=data_format,
                      is_training=is_training)
        node = conv2d(node,
                      num_filters,
                      kernel_size=local_kernel_size,
                      name='conv2',
                      activation=activation,
                      kernel_initializer=kernel_initializer,
                      padding=padding,
                      data_format=data_format,
                      is_training=is_training)
        node = conv2d(node,
                      num_filters,
                      kernel_size=local_kernel_size,
                      name='conv3',
                      activation=activation,
                      kernel_initializer=kernel_initializer,
                      padding=padding,
                      data_format=data_format,
                      is_training=is_training)
        local_heatmaps = conv2d(node,
                                num_landmarks,
                                kernel_size=local_kernel_size,
                                name='local_heatmaps',
                                activation=local_activation,
                                kernel_initializer=heatmap_initializer,
                                padding=padding,
                                data_format=data_format,
                                is_training=is_training)
    with tf.variable_scope('spatial_configuration'):
        local_heatmaps_downsampled = avg_pool2d(local_heatmaps,
                                                [downsampling_factor, downsampling_factor],
                                                name='local_heatmaps_downsampled',
                                                data_format=data_format)
        channel_axis = get_channel_index(local_heatmaps_downsampled, data_format)
        local_heatmaps_downsampled_split = tf.split(local_heatmaps_downsampled, num_landmarks, channel_axis)
        spatial_heatmaps_downsampled_split = []
        for i in range(num_landmarks):
            # predict the heatmap of landmark i from the heatmaps of all other landmarks
            local_heatmaps_except_i = tf.concat([local_heatmaps_downsampled_split[j]
                                                 for j in range(num_landmarks) if i != j],
                                                name='h_app_except_' + str(i),
                                                axis=channel_axis)
            h_acc = conv2d(local_heatmaps_except_i,
                           1,
                           kernel_size=spatial_kernel_size,
                           name='h_acc_' + str(i),
                           activation=spatial_activation,
                           kernel_initializer=heatmap_initializer,
                           padding=padding,
                           data_format=data_format,
                           is_training=is_training)
            spatial_heatmaps_downsampled_split.append(h_acc)
        spatial_heatmaps_downsampled = tf.concat(spatial_heatmaps_downsampled_split,
                                                 name='spatial_heatmaps_downsampled',
                                                 axis=channel_axis)
        spatial_heatmaps = upsample2d_linear(spatial_heatmaps_downsampled,
                                             [downsampling_factor, downsampling_factor],
                                             name='spatial_prediction',
                                             padding='valid_cropped',
                                             data_format=data_format)
    with tf.variable_scope('combination'):
        heatmaps = local_heatmaps * spatial_heatmaps
    return heatmaps

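# Minimal usage sketch (an assumption): builds the SCN graph, consisting of the
# local appearance and spatial configuration components, for single-channel 2D
# images with 6 landmarks; conv2d, avg_pool2d, upsample2d_linear and
# he_initializer come from this repository's layer wrappers.
image = tf.placeholder(tf.float32, shape=[1, 1, 256, 256])
landmark_heatmaps = network_scn(image, num_landmarks=6, is_training=True)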