def instance_norm(inputs,
                  is_training,
                  name='',
                  data_format='channels_first',
                  epsilon=1e-5,
                  beta_initializer=tf.constant_initializer(0.0),
                  gamma_initializer=tf.constant_initializer(1.0)):
    """
    Instance normalization with a learned per-channel offset and scale.
    The moments are computed over the image axes only (as returned by get_image_axes),
    so the statistics are not reduced over the remaining axes.
    :param inputs: tensor to normalize
    :param is_training: passed as the `trainable` flag of the 'offset' and 'scale' variables
    :param name: name of the variable scope in which the variables are created
    :param data_format: forwarded to get_channel_index and get_image_axes
                        (presumably 'channels_first' or 'channels_last')
    :param epsilon: added to the variance before taking the reciprocal square root
    :param beta_initializer: initializer of the 'offset' variable
    :param gamma_initializer: initializer of the 'scale' variable
    :return: scale * normalized + offset, wrapped in an identity op named 'output'
    """
    with tf.variable_scope(name):
        channel_axis = get_channel_index(inputs, data_format)
        spatial_axes = get_image_axes(inputs, data_format=data_format)
        num_channels = inputs.get_shape()[channel_axis]

        # all-ones shape except for the channel axis, such that the per-channel
        # variables broadcast against the inputs
        broadcast_shape = [1] * inputs.shape.ndims
        broadcast_shape[channel_axis] = num_channels

        # normalize with the moments of the image axes
        mean, variance = tf.nn.moments(inputs, axes=spatial_axes, keep_dims=True)
        inv_std = tf.rsqrt(variance + epsilon)
        normalized = (inputs - mean) * inv_std

        # learned affine parameters, one value per channel
        beta = tf.get_variable('offset', [num_channels],
                               initializer=beta_initializer,
                               trainable=is_training)
        gamma = tf.get_variable('scale', [num_channels],
                                initializer=gamma_initializer,
                                trainable=is_training)
        beta = tf.reshape(beta, broadcast_shape)
        gamma = tf.reshape(gamma, broadcast_shape)

        return tf.identity(gamma * normalized + beta, name='output')
def generalized_dice_loss(labels,
                          logits=None,
                          logits_as_probability=None,
                          data_format='channels_first',
                          weights=None,
                          weight_labels=True,
                          squared=True,
                          weight_epsilon=1e-08,
                          epsilon=1e-08):
    """
    Taken from Generalised Dice overlap as a deep learning loss function for
    highly unbalanced segmentations (https://arxiv.org/pdf/1707.03237.pdf)
    Exactly one of logits and logits_as_probability must be given.
    :param labels: groundtruth labels
    :param logits: network predictions; converted to probabilities with a softmax over the channel axis
    :param logits_as_probability: network predictions that are already probabilities; used instead of
                                  the softmax of logits
    :param data_format: either 'channels_first' or 'channels_last'
    :param weights: optional weights; if given, they are squeezed along their channel axis and the
                    loss is reduced with reduce_mean_weighted instead of tf.reduce_mean
    :param weight_labels: if True, each label is weighted with 1 / (sum of its label values ** 2 + weight_epsilon),
                          otherwise all labels are weighted with 1
    :param squared: if True, the probabilities are squared in the denominator
    :param weight_epsilon: used for numerical instabilities caused by a label weight denominator = 0
    :param epsilon: used for numerical instabilities caused by denominator = 0
    :return: Tensor of mean generalized dice loss of all images of the batch
    """
    assert (logits is None and logits_as_probability is not None) or (
        logits is not None and logits_as_probability is None
    ), 'Set either logits or logits_as_probability, but not both.'

    channel_index = get_channel_index(labels, data_format)
    image_axes = get_image_axes(labels, data_format)

    # calculate logits probability as softmax (p_n)
    if logits_as_probability is None:
        logits_as_probability = tf.nn.softmax(logits, dim=channel_index)

    if weight_labels:
        # calculate label weights (w_l)
        label_weights = 1 / (tf.reduce_sum(labels, axis=image_axes) ** 2 + weight_epsilon)
    else:
        label_weights = 1

    # GDL_b based on equation in reference paper
    # NOTE(review): axis=1 presumably addresses the label axis that remains next to the
    # batch axis after the reduction over the image axes - confirm against get_image_axes
    numerator = tf.reduce_sum(
        label_weights * tf.reduce_sum(labels * logits_as_probability, axis=image_axes),
        axis=1)

    if squared:
        # square logits, no need to square labels, as they are either 0 or 1
        prediction_term = logits_as_probability ** 2
    else:
        prediction_term = logits_as_probability
    denominator = tf.reduce_sum(
        label_weights * tf.reduce_sum(labels + prediction_term, axis=image_axes),
        axis=1)

    loss = 1 - 2 * (numerator + epsilon) / (denominator + epsilon)

    if weights is not None:
        # remove the channel axis of the weights before the weighted reduction
        weights_channel_index = get_channel_index(weights, data_format)
        weights = tf.squeeze(weights, axis=weights_channel_index)
        return reduce_mean_weighted(loss, weights)
    else:
        return tf.reduce_mean(loss)
def cosine_embedding_per_instance_loss(embeddings,
                                       instances,
                                       normalize=False,
                                       l=1.0,
                                       term_1_squared=False,
                                       term_2_factor=0,
                                       data_format='channels_first',
                                       parallel_iterations=4):
    """
    Mean of cosine_embedding_single_instance_loss over all instances.
    The first two axes of instances are swapped, such that tf.map_fn iterates over the original
    second axis. For each slice, entries equal to 1 are used as the target mask and entries equal
    to 2 are used as the mask of the other instances.
    :param embeddings: embeddings tensor
    :param instances: instances tensor; presumably laid out as (batch, instance, ...) with the same
                      rank as embeddings - TODO confirm against callers
    :param normalize: forwarded to cosine_embedding_single_instance_loss
    :param l: forwarded to cosine_embedding_single_instance_loss (factor of term_1)
    :param term_1_squared: forwarded to cosine_embedding_single_instance_loss
    :param term_2_factor: forwarded to cosine_embedding_single_instance_loss
    :param data_format: 'channels_first' uses axis 1 as the embedding axis, anything else uses the last axis
    :param parallel_iterations: forwarded to tf.map_fn
    :return: Tensor with the mean of the per instance losses
    """
    if data_format == 'channels_first':
        embedding_axis = 1
    else:
        embedding_axis = len(embeddings.shape) - 1

    # swap the first two axes and keep the remaining ones in place; this yields
    # [1, 0, 2, 3] for rank 4 and [1, 0, 2, 3, 4] for rank 5
    rank = len(embeddings.shape)
    transpose_perm = [1, 0] + list(range(2, rank))
    # NOTE(review): the singleton axis is always inserted at position 2, i.e., at the embedding
    # axis of 'channels_first' slices - verify the mask layout for 'channels_last'
    instances_transposed = tf.expand_dims(tf.transpose(instances, transpose_perm), axis=2)

    # normalize once here, such that it is not recomputed for every instance
    embeddings_norm = tf.nn.l2_normalize(embeddings, dim=embedding_axis)

    def per_instance_loss(instance):
        # 1 marks pixels of the target instance, 2 marks pixels of other instances
        return cosine_embedding_single_instance_loss(embeddings,
                                                     tf.equal(instance, 1),
                                                     tf.equal(instance, 2),
                                                     embeddings_norm=embeddings_norm,
                                                     normalize=normalize,
                                                     l=l,
                                                     term_1_squared=term_1_squared,
                                                     data_format=data_format,
                                                     term_2_factor=term_2_factor)

    loss_list = tf.map_fn(per_instance_loss,
                          instances_transposed,
                          swap_memory=True,
                          dtype=tf.float32,
                          parallel_iterations=parallel_iterations)
    return tf.reduce_mean(loss_list)
def cosine_embedding_single_instance_loss(embeddings,
                                          target_instances_mask,
                                          other_instances_mask,
                                          embeddings_norm=None,
                                          normalize=False,
                                          l=1.0,
                                          term_1_squared=False,
                                          term_2_factor=0,
                                          use_first_frame_for_mean=False,
                                          data_format='channels_first'):
    """
    Cosine embedding loss of a single instance: term_0 + l * term_1 + term_2.
    term_0 is (1 - cos_similarity) reduced over the target mask, term_1 is a penalty on the
    cos_similarity reduced over the mask of the other instances, and term_2 is an optional
    L1 norm term of the normalized mean target embedding.
    :param embeddings: embeddings tensor
    :param target_instances_mask: mask of the target instance pixels
                                  (presumably boolean and broadcastable to embeddings - TODO confirm)
    :param other_instances_mask: mask of the pixels of the other instances
    :param embeddings_norm: l2 normalized embeddings; calculated from embeddings if None
    :param normalize: if True, term_0 and term_1 are reduced with reduce_mean_masked,
                      otherwise with reduce_sum_masked
    :param l: factor of term_1
    :param term_1_squared: if True, term_1 is cos_similarity ** 2, otherwise relu(cos_similarity)
    :param term_2_factor: factor of term_2; term_2 is only calculated if the factor is > 0
    :param use_first_frame_for_mean: if True, only the first entry along the frame axis is used for
                                     calculating the mean target embedding (slices five axes)
    :param data_format: 'channels_first' uses embedding axis 1 and frame axis 2,
                        anything else uses the last axis as embedding axis and frame axis 1
    :return: Tensor of the loss
    """
    spatial_axes = list(get_image_axes(embeddings, data_format))
    channels_first = data_format == 'channels_first'
    embedding_axis = 1 if channels_first else len(embeddings.shape) - 1
    frame_axis = 2 if channels_first else 1

    # calculate mean embedding for target pixels
    if use_first_frame_for_mean:
        # keep all axes as they are, but restrict the frame axis to its first entry
        first_frame = [slice(0, 1) if axis == frame_axis else slice(None) for axis in range(5)]
        mean_embeddings = embeddings[first_frame]
        mean_mask = target_instances_mask[first_frame]
    else:
        mean_embeddings = embeddings
        mean_mask = target_instances_mask
    h = reduce_mean_masked(mean_embeddings, mean_mask, axis=spatial_axes, keepdims=True)

    # l2_normalize embeddings -> needed for cos_similarity
    if embeddings_norm is None:
        embeddings_norm = tf.nn.l2_normalize(embeddings, dim=embedding_axis)
    h_norm = tf.nn.l2_normalize(h, dim=embedding_axis)

    # calculate cos_similarity with target mean embedding and all embeddings
    cos_similarity = tf.reduce_sum(h_norm * embeddings_norm, axis=embedding_axis, keepdims=True)

    # term_0: target mean embedding and target pixel embeddings should be as similar as possible
    term_0 = 1 - cos_similarity
    if term_1_squared:
        # term_1: target mean embedding and other pixel embeddings should be orthogonal (== 0)
        term_1 = cos_similarity ** 2
    else:
        # term_1: target mean embedding and other pixel embeddings should be far apart (>= 0)
        term_1 = tf.nn.relu(cos_similarity)

    # either reduce_mean or reduce_sum on target and other pixel masks
    reduce_masked = reduce_mean_masked if normalize else reduce_sum_masked
    term_0 = reduce_masked(term_0, target_instances_mask)
    term_1 = reduce_masked(term_1, other_instances_mask)

    # term_2: L1 norm of the normalized mean embedding, masked by whether any target pixel
    # exists along the image axes
    term_2 = 0
    if term_2_factor > 0:
        instance_mask = tf.reduce_any(target_instances_mask, axis=spatial_axes, keepdims=True)
        h_l1_norm = tf.norm(h_norm, ord=1, axis=embedding_axis, keepdims=True)
        term_2 = reduce_mean_masked(h_l1_norm, instance_mask) * term_2_factor

    return term_0 + l * term_1 + term_2