Example #1
File: normalizers.py  Project: auger1/MMWHS
import tensorflow as tf

# Note: get_channel_index and get_image_axes are helpers defined elsewhere in this project.
def instance_norm(inputs,
                  is_training,
                  name='',
                  data_format='channels_first',
                  epsilon=1e-5,
                  beta_initializer=tf.constant_initializer(0.0),
                  gamma_initializer=tf.constant_initializer(1.0)):
    with tf.variable_scope(name):
        channel_index = get_channel_index(inputs, data_format)
        image_axes = get_image_axes(inputs, data_format=data_format)
        depth = inputs.get_shape()[channel_index]
        # per-image statistics over the spatial axes, computed per sample and per channel
        mean, variance = tf.nn.moments(inputs, axes=image_axes, keep_dims=True)
        inv = tf.rsqrt(variance + epsilon)
        normalized = (inputs - mean) * inv
        offset = tf.get_variable('offset', [depth],
                                 trainable=is_training,
                                 initializer=beta_initializer)
        scale = tf.get_variable('scale', [depth],
                                trainable=is_training,
                                initializer=gamma_initializer)
        # reshape offset and scale so they broadcast along the channel axis
        offset_scale_shape = [1] * inputs.shape.ndims
        offset_scale_shape[channel_index] = depth
        offset = tf.reshape(offset, offset_scale_shape)
        scale = tf.reshape(scale, offset_scale_shape)
        return tf.identity(scale * normalized + offset, name='output')
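A minimal usage sketch with hypothetical shapes (TF1-style graph and session, as in the code above):

import numpy as np
import tensorflow as tf

inputs = tf.placeholder(tf.float32, [8, 16, 32, 32])  # [N, C, H, W], hypothetical
outputs = instance_norm(inputs, is_training=True, name='inorm',
                        data_format='channels_first')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(outputs,
                      {inputs: np.random.rand(8, 16, 32, 32).astype(np.float32)})
    # each (sample, channel) slice of result now has ~zero mean and unit variance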
Example #2
def generalized_dice_loss(labels,
                          logits=None,
                          logits_as_probability=None,
                          data_format='channels_first',
                          weights=None,
                          weight_labels=True,
                          squared=True,
                          weight_epsilon=1e-08,
                          epsilon=1e-08):
    """
    Taken from Generalised Dice overlap as a deep learning loss function for
    highly unbalanced segmentations (https://arxiv.org/pdf/1707.03237.pdf)
    :param labels: groundtruth labels
    :param logits: network predictions
    :param data_format: either 'channels_first' of 'channels_last'
    :param epsilon: used for numerical instabilities caused by denominator = 0
    :return: Tensor of mean generalized dice loss of all images of the batch
    """
    assert (logits is None and logits_as_probability is not None) or (
        logits is not None and logits_as_probability is None
    ), 'Set either logits or logits_as_probability, but not both.'
    channel_index = get_channel_index(labels, data_format)
    image_axes = get_image_axes(labels, data_format)
    labels_shape = labels.get_shape().as_list()
    num_labels = labels_shape[channel_index]
    # calculate logits probability as softmax (p_n)
    if logits_as_probability is None:
        logits_as_probability = tf.nn.softmax(logits, dim=channel_index)
    if weight_labels:
        # calculate label weights (w_l)
        label_weights = 1 / (tf.reduce_sum(labels, axis=image_axes)**2 +
                             weight_epsilon)
    else:
        label_weights = 1
    # GDL_b based on equation in reference paper
    numerator = tf.reduce_sum(
        label_weights *
        tf.reduce_sum(labels * logits_as_probability, axis=image_axes),
        axis=1)
    if squared:
        # square logits, no need to square labels, as they are either 0 or 1
        denominator = tf.reduce_sum(
            label_weights *
            tf.reduce_sum(labels +
                          (logits_as_probability**2), axis=image_axes),
            axis=1)
    else:
        denominator = tf.reduce_sum(
            label_weights *
            tf.reduce_sum(labels + logits_as_probability, axis=image_axes),
            axis=1)
    loss = 1 - 2 * (numerator + epsilon) / (denominator + epsilon)

    if weights is not None:
        channel_index = get_channel_index(weights, data_format)
        weights = tf.squeeze(weights, axis=channel_index)
        return reduce_mean_weighted(loss, weights)
    else:
        return tf.reduce_mean(loss)
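For reference, the loss computed per batch element is the generalized Dice loss from the cited paper, with r the one-hot labels, p the softmax probabilities, ε = epsilon, and ε_w = weight_epsilon:

\[
\mathrm{GDL} = 1 - 2\,\frac{\sum_l w_l \sum_n r_{ln}\,p_{ln} + \epsilon}
                           {\sum_l w_l \sum_n \left(r_{ln} + p_{ln}\right) + \epsilon},
\qquad
w_l = \frac{1}{\left(\sum_n r_{ln}\right)^2 + \epsilon_w}
\]

With squared=True, p_{ln} in the denominator is replaced by p_{ln}^2; the labels need no squaring, since r_{ln} ∈ {0, 1}.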
Example #3
def cosine_embedding_per_instance_loss(embeddings,
                                       instances,
                                       normalize=False,
                                       l=1.0,
                                       term_1_squared=False,
                                       term_2_factor=0,
                                       data_format='channels_first',
                                       parallel_iterations=4):
    image_axes = get_image_axes(embeddings, data_format)
    if data_format == 'channels_first':
        embedding_axis = 1
    else:
        embedding_axis = len(embeddings.shape) - 1
    # move the instance axis (axis 1) to the front so tf.map_fn iterates over
    # instances, and insert a singleton axis that becomes the broadcastable
    # channel dimension of each instance slice
    if len(embeddings.shape) == 5:
        instances_transposed = tf.expand_dims(
            tf.transpose(instances, [1, 0, 2, 3, 4]), axis=2)
    else:
        instances_transposed = tf.expand_dims(
            tf.transpose(instances, [1, 0, 2, 3]), axis=2)
    embeddings_norm = tf.nn.l2_normalize(embeddings, dim=embedding_axis)
    # each per-instance slice codes the instance's own pixels as 1 and all
    # other instances' pixels as 2
    per_instance_loss = lambda i: cosine_embedding_single_instance_loss(
        embeddings,
        tf.equal(i, 1),
        tf.equal(i, 2),
        embeddings_norm=embeddings_norm,
        normalize=normalize,
        l=l,
        term_1_squared=term_1_squared,
        data_format=data_format,
        term_2_factor=term_2_factor)
    loss_list = tf.map_fn(per_instance_loss,
                          instances_transposed,
                          swap_memory=True,
                          dtype=tf.float32,
                          parallel_iterations=parallel_iterations)
    return tf.reduce_mean(loss_list)
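A minimal usage sketch with hypothetical shapes; the instance coding matches the tf.equal calls above (1 = the instance's own pixels, 2 = all other instances' pixels):

import tensorflow as tf

embeddings = tf.placeholder(tf.float32, [1, 8, 64, 64])  # [N, C, H, W]
instances = tf.placeholder(tf.int32, [1, 3, 64, 64])     # [N, num_instances, H, W]
loss = cosine_embedding_per_instance_loss(embeddings, instances,
                                          normalize=True, term_1_squared=True)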
Example #4
def cosine_embedding_single_instance_loss(embeddings,
                                          target_instances_mask,
                                          other_instances_mask,
                                          embeddings_norm=None,
                                          normalize=False,
                                          l=1.0,
                                          term_1_squared=False,
                                          term_2_factor=0,
                                          use_first_frame_for_mean=False,
                                          data_format='channels_first'):
    image_axes = get_image_axes(embeddings, data_format)
    if data_format == 'channels_first':
        embedding_axis = 1
        frame_axis = 2
    else:
        embedding_axis = len(embeddings.shape) - 1
        frame_axis = 1
    # target_instances_mask marks the target instance's pixels;
    # other_instances_mask marks all other instances' pixels
    # calculate mean embedding for target pixels
    if use_first_frame_for_mean:
        # for 5D (video) inputs, compute the mean only over the first frame
        slices = [slice(None)] * 4
        slices.insert(frame_axis, slice(0, 1))
        h = reduce_mean_masked(embeddings[tuple(slices)],
                               target_instances_mask[tuple(slices)],
                               axis=image_axes,
                               keepdims=True)
    else:
        h = reduce_mean_masked(embeddings,
                               target_instances_mask,
                               axis=image_axes,
                               keepdims=True)
    if embeddings_norm is None:
        embeddings_norm = tf.nn.l2_normalize(embeddings, dim=embedding_axis)
    # l2_normalize embeddings -> needed for cosine similarity
    h_norm = tf.nn.l2_normalize(h, dim=embedding_axis)
    # calculate cos_similarity with target mean embedding and all embeddings
    cos_similarity = tf.reduce_sum(h_norm * embeddings_norm,
                                   axis=embedding_axis,
                                   keepdims=True)
    # term_0: target mean embedding and target pixel embeddings should be as similar as possible
    term_0 = 1 - cos_similarity
    if term_1_squared:
        # term_1: target mean embedding and other pixel embeddings should be orthogonal (== 0)
        term_1 = cos_similarity**2
    else:
        # term_1: target mean embedding and other pixel embeddings should be far apart (>= 0)
        term_1 = tf.nn.relu(cos_similarity)

    # either reduce_mean or reduce_sum on target and other pixel masks
    if normalize:
        term_0 = reduce_mean_masked(term_0, target_instances_mask)
        term_1 = reduce_mean_masked(term_1, other_instances_mask)
    else:
        term_0 = reduce_sum_masked(term_0, target_instances_mask)
        term_1 = reduce_sum_masked(term_1, other_instances_mask)

    term_2 = 0
    if term_2_factor > 0:
        # term_2: L1 penalty on the normalized mean embedding, averaged over
        # entries where the instance is present at all
        instance_mask = tf.reduce_any(target_instances_mask,
                                      axis=image_axes,
                                      keepdims=True)
        term_2 = tf.norm(h_norm, ord=1, axis=embedding_axis, keepdims=True)
        term_2 = reduce_mean_masked(term_2, instance_mask) * term_2_factor

    loss = term_0 + l * term_1 + term_2
    return loss
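The single-instance loss can also be called directly with precomputed boolean masks; a minimal sketch with hypothetical shapes:

import tensorflow as tf

embeddings = tf.placeholder(tf.float32, [1, 8, 64, 64])  # [N, C, H, W]
target_mask = tf.placeholder(tf.bool, [1, 1, 64, 64])    # the target instance's pixels
other_mask = tf.placeholder(tf.bool, [1, 1, 64, 64])     # all other instances' pixels
loss = cosine_embedding_single_instance_loss(embeddings, target_mask, other_mask,
                                             normalize=True, term_1_squared=True)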