Example #1
def generalized_dice_loss(labels,
                          logits=None,
                          logits_as_probability=None,
                          data_format='channels_first',
                          weights=None,
                          weight_labels=True,
                          squared=True,
                          weight_epsilon=1e-08,
                          epsilon=1e-08):
    """
    Taken from Generalised Dice overlap as a deep learning loss function for
    highly unbalanced segmentations (https://arxiv.org/pdf/1707.03237.pdf)
    :param labels: groundtruth labels
    :param logits: network predictions
    :param data_format: either 'channels_first' of 'channels_last'
    :param epsilon: used for numerical instabilities caused by denominator = 0
    :return: Tensor of mean generalized dice loss of all images of the batch
    """
    assert (logits is None and logits_as_probability is not None) or (
        logits is not None and logits_as_probability is None
    ), 'Set either logits or logits_as_probability, but not both.'
    channel_index = get_channel_index(labels, data_format)
    image_axes = get_image_axes(labels, data_format)
    labels_shape = labels.get_shape().as_list()
    num_labels = labels_shape[channel_index]
    # calculate logits probability as softmax (p_n)
    if logits_as_probability is None:
        logits_as_probability = tf.nn.softmax(logits, dim=channel_index)
    if weight_labels:
        # calculate label weights (w_l)
        label_weights = 1 / (tf.reduce_sum(labels, axis=image_axes)**2 +
                             weight_epsilon)
    else:
        label_weights = 1
    # GDL_b based on equation in reference paper
    numerator = tf.reduce_sum(
        label_weights *
        tf.reduce_sum(labels * logits_as_probability, axis=image_axes),
        axis=1)
    if squared:
        # square logits, no need to square labels, as they are either 0 or 1
        denominator = tf.reduce_sum(
            label_weights *
            tf.reduce_sum(labels +
                          (logits_as_probability**2), axis=image_axes),
            axis=1)
    else:
        denominator = tf.reduce_sum(
            label_weights *
            tf.reduce_sum(labels + logits_as_probability, axis=image_axes),
            axis=1)
    loss = 1 - 2 * (numerator + epsilon) / (denominator + epsilon)

    if weights is not None:
        channel_index = get_channel_index(weights, data_format)
        weights = tf.squeeze(weights, axis=channel_index)
        return reduce_mean_weighted(loss, weights)
    else:
        return tf.reduce_mean(loss)
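A minimal usage sketch (not part of the original source; the shapes, the class count of 6, and the placeholder-based setup are assumptions, and the project helpers such as get_channel_index must be importable): with TF 1.x and one-hot labels in channels_first layout, the loss can be wired up as follows.

import tensorflow as tf

# hypothetical shapes: batch of 8, 6 classes, 64x64 images (NCHW)
labels = tf.placeholder(tf.float32, [8, 6, 64, 64])  # one-hot groundtruth
logits = tf.placeholder(tf.float32, [8, 6, 64, 64])  # raw network outputs
loss = generalized_dice_loss(labels, logits=logits,
                             data_format='channels_first')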
Example #2
def softmax_cross_entropy_with_logits(labels, logits, weights=None, data_format='channels_first'):
    channel_index = get_channel_index(labels, data_format)
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits, dim=channel_index)
    if weights is not None:
        channel_index = get_channel_index(weights, data_format)
        weights = tf.squeeze(weights, axis=channel_index)
        return reduce_mean_weighted(loss, weights)
    else:
        return tf.reduce_mean(loss)
Example #3
def weighted_spread_loss(labels,
                         logits,
                         m_low=0.2,
                         m_high=0.9,
                         iteration_low_to_high=100000,
                         global_step=100000,
                         data_format='channels_first'):
    w_l = np.array([
        0.00705479, 0.03312549, 0.02664785, 0.4437354, 0.44254721, 0.04688926
    ]) * 6
    channel_index = get_channel_index(labels, data_format)

    m = m_low + (m_high - m_low) * tf.minimum(
        tf.to_float(global_step / iteration_low_to_high), tf.to_float(1))
    n_labels = labels.get_shape()[1]
    labels = tf.transpose(labels, [1, 0, 2, 3])
    logits = tf.transpose(logits, [1, 0, 2, 3])
    labels = tf.reshape(labels, [n_labels, -1])
    logits = tf.reshape(logits, [n_labels, -1])

    true_class_logits = tf.reduce_max(labels * logits, axis=0)
    margin_loss_pixel_class = tf.square(
        tf.nn.relu((m - true_class_logits + logits) * tf.abs(labels - 1)))

    loss = []
    for i in range(len(w_l)):
        # after the transpose/reshape above, the label axis is 0, not the
        # original channel_index
        loss.append(w_l[i] * margin_loss_pixel_class *
                    tf.gather(labels, [i], axis=0))

    loss = tf.reduce_sum(tf.stack(loss, axis=0), axis=0)
    loss = tf.reduce_mean(tf.reduce_sum(loss, axis=0))

    return loss
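The margin m ramps linearly from m_low to m_high over iteration_low_to_high steps and then saturates. A plain-Python sketch of that schedule, using the defaults from the signature (the helper function name is hypothetical):

def margin(step, m_low=0.2, m_high=0.9, ramp=100000):
    # m(step) = m_low + (m_high - m_low) * min(step / ramp, 1)
    return m_low + (m_high - m_low) * min(step / ramp, 1.0)

print(margin(0))       # 0.2
print(margin(50000))   # ~0.55, halfway through the ramp
print(margin(200000))  # 0.9, saturated at m_high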
Example #4
File: normalizers.py Project: auger1/MMWHS
def layer_norm(inputs, is_training, name='', data_format='channels_first'):
    with tf.variable_scope(name):
        inputs_shape = inputs.get_shape().as_list()
        channel_index = get_channel_index(inputs, data_format)
        params_shape = [1] * len(inputs_shape)
        params_shape[channel_index] = inputs_shape[channel_index]
        # Allocate parameters for the beta and gamma of the normalization.
        beta = tf.get_variable('beta',
                               shape=params_shape,
                               dtype=tf.float32,
                               initializer=tf.zeros_initializer(),
                               trainable=is_training)
        gamma = tf.get_variable('gamma',
                                shape=params_shape,
                                dtype=tf.float32,
                                initializer=tf.ones_initializer(),
                                trainable=is_training)
        norm_axes = list(range(1, len(inputs_shape)))
        mean, variance = tf.nn.moments(inputs, norm_axes, keep_dims=True)
        # Compute layer normalization using the batch_normalization function.
        outputs = tf.nn.batch_normalization(inputs,
                                            mean,
                                            variance,
                                            offset=beta,
                                            scale=gamma,
                                            variance_epsilon=1e-12)
        return outputs
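For reference, feeding per-layer statistics into tf.nn.batch_normalization reduces to the standard layer-norm formula. A NumPy sketch of the same computation (the 4D NCHW shape is an assumption):

import numpy as np

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
mean = x.mean(axis=(1, 2, 3), keepdims=True)  # statistics over all non-batch axes
var = x.var(axis=(1, 2, 3), keepdims=True)
beta, gamma = 0.0, 1.0  # stands in for the learned offset/scale parameters
y = gamma * (x - mean) / np.sqrt(var + 1e-12) + beta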
Example #5
File: normalizers.py Project: auger1/MMWHS
def instance_norm(inputs,
                  is_training,
                  name='',
                  data_format='channels_first',
                  epsilon=1e-5,
                  beta_initializer=tf.constant_initializer(0.0),
                  gamma_initializer=tf.constant_initializer(1.0)):
    with tf.variable_scope(name):
        channel_index = get_channel_index(inputs, data_format)
        image_axes = get_image_axes(inputs, data_format=data_format)
        depth = inputs.get_shape()[channel_index]
        mean, variance = tf.nn.moments(inputs, axes=image_axes, keep_dims=True)
        inv = tf.rsqrt(variance + epsilon)
        normalized = (inputs - mean) * inv
        offset = tf.get_variable('offset', [depth],
                                 trainable=is_training,
                                 initializer=beta_initializer)
        scale = tf.get_variable('scale', [depth],
                                trainable=is_training,
                                initializer=gamma_initializer)
        offset_scale_shape = [1] * inputs.shape.ndims
        offset_scale_shape[channel_index] = depth
        offset = tf.reshape(offset, offset_scale_shape)
        scale = tf.reshape(scale, offset_scale_shape)
        return tf.identity(scale * normalized + offset, name='output')
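In contrast to layer_norm above, the statistics here are taken over the image axes only, separately per sample and per channel. A NumPy sketch of the core computation (NCHW layout and the default epsilon assumed; the learned scale and offset are omitted):

import numpy as np

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
mean = x.mean(axis=(2, 3), keepdims=True)  # per sample, per channel
var = x.var(axis=(2, 3), keepdims=True)
normalized = (x - mean) / np.sqrt(var + 1e-5)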
Example #6
def batch_norm(inputs, is_training, name='', data_format='channels_first'):
    # use the faster fused batch_norm for 4D and 5D image tensors
    if inputs.shape.ndims == 4 or inputs.shape.ndims == 5:
        channel_index = get_channel_index(inputs, data_format)
        return tf.layers.batch_normalization(inputs, axis=channel_index, name=name + '/bn', training=is_training)
    else:
        raise Exception('This batch_norm only supports images. Use batch_norm_dense or basic tensorflow version instead.')
Example #7
def concat_channels(inputs,
                    name='',
                    data_format='channels_first',
                    debug_print=debug_print_others):
    axis = get_channel_index(inputs[0], data_format)
    outputs = tf.concat(inputs, axis=axis, name=name)
    if debug_print:
        print_shape_parameters(inputs, outputs, name, 'concat')
    return outputs
Example #8
def pad_for_conv(inputs, kernel_size, name, padding, data_format):
    if padding in ['symmetric', 'reflect']:
        # TODO check if this works for even kernels
        channel_index = get_channel_index(inputs, data_format)
        paddings = np.array([[0, 0]] + [[int(ks / 2)] * 2 for ks in kernel_size])
        paddings = np.insert(paddings, channel_index, [0, 0], axis=0)
        outputs = tf.pad(inputs, paddings, mode=padding, name=name+'/pad')
        padding_for_conv = 'valid'
    else:
        outputs = inputs
        padding_for_conv = padding
    return outputs, padding_for_conv
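A worked example of the paddings construction (assuming a 4D channels_first input, where get_channel_index is expected to return 1): for kernel_size [3, 3], each image axis gets a [1, 1] entry, and np.insert places the channel's [0, 0] entry at index 1.

import numpy as np

kernel_size = [3, 3]
channel_index = 1  # assumed for a 4D channels_first tensor
paddings = np.array([[0, 0]] + [[int(ks / 2)] * 2 for ks in kernel_size])
paddings = np.insert(paddings, channel_index, [0, 0], axis=0)
print(paddings)  # [[0 0] [0 0] [1 1] [1 1]]: only the two image axes are padded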
Example #9
def pad_for_conv(inputs, kernel_size, name, padding, data_format):
    if padding in ['symmetric', 'reflect']:
        # TODO check if this works for even kernels
        channel_index = get_channel_index(inputs, data_format)
        paddings = np.array([[0, 0]] + [[int(ks / 2)] * 2 for ks in kernel_size])
        paddings = np.insert(paddings, channel_index, [0, 0], axis=0)
        outputs = tf.pad(inputs, paddings, mode=padding, name=name+'/pad')
        padding_for_conv = 'valid'
    elif padding == 'same_selu':
        # TODO check if this works for even kernels
        channel_index = get_channel_index(inputs, data_format)
        paddings = np.array([[0, 0]] + [[int(ks / 2)] * 2 for ks in kernel_size])
        paddings = np.insert(paddings, channel_index, [0, 0], axis=0)
        # constant_values = -lambda * alpha from selu paper
        alpha = 1.6732632423543772848170429916717
        scale = 1.0507009873554804934193349852946
        pad_value = -alpha * scale
        outputs = tf.pad(inputs, paddings, mode='constant', constant_values=pad_value, name=name+'/pad')
        padding_for_conv = 'valid'
    else:
        outputs = inputs
        padding_for_conv = padding
    return outputs, padding_for_conv
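The 'same_selu' branch pads with the saturation value of the SELU activation: for x -> -inf, selu(x) = scale * alpha * (exp(x) - 1) tends to -scale * alpha, so padded border pixels behave like strongly negative pre-activations rather than zeros. A quick check of the constant:

alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
print(-alpha * scale)  # about -1.7581, the pad_value used above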
Example #10
def weighted_softmax(labels, logits, data_format='channels_first'):
    w_l = np.array([
        0.00705479, 0.03312549, 0.02664785, 0.4437354, 0.44254721, 0.04688926
    ]) * 6
    channel_index = get_channel_index(labels, data_format)
    loss_s = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels,
                                                        logits=logits,
                                                        dim=channel_index)

    loss = []
    for i in range(len(w_l)):
        # gather with a scalar index so the mask has the same rank as the
        # per-pixel cross-entropy and broadcasts correctly
        loss.append(w_l[i] * loss_s *
                    tf.gather(labels, i, axis=channel_index))

    loss = tf.reduce_sum(tf.stack(loss, axis=0), axis=0)
    return tf.reduce_mean(loss)
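Note that the hard-coded class weights sum to 1, so the factor 6 rescales them to an average of 1 over the 6 classes and keeps the loss magnitude comparable to the unweighted cross-entropy:

import numpy as np

w_l = np.array([0.00705479, 0.03312549, 0.02664785, 0.4437354, 0.44254721, 0.04688926])
print(w_l.sum())         # 1.0 (up to float rounding)
print((w_l * 6).mean())  # 1.0 (up to float rounding)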
Example #11
def upsample_interpolation_function(inputs, factors, interpolation_function,
                                    support, name, data_format, padding):
    with tf.variable_scope(name):
        dim = len(factors)
        kernel = get_filler_kernel(interpolation_function, support, factors)
        kernel = kernel.reshape(kernel.shape + (1, 1))
        padding = padding.upper()

        # set padding and cropping parameters
        cropping = False
        if padding == 'VALID_CROPPED':
            cropping = True
            padding = 'VALID'

        # calculate convolution parameters
        input_size = get_image_size(inputs, data_format)
        data_format_tf = get_tf_data_format(inputs, data_format)
        inputs_shape = inputs.get_shape().as_list()
        channel_axis = get_channel_index(inputs, data_format)
        num_inputs = inputs_shape[channel_axis]

        # calculate output_size dependent on padding parameter
        if padding == 'SAME':
            output_size = [input_size[i] * factors[i] for i in range(dim)]
        else:
            output_size = [
                input_size[i] * factors[i] + kernel.shape[i] - factors[i]
                for i in range(dim)
            ]
        output_shape = get_tensor_shape(batch_size=inputs_shape[0],
                                        channel_size=1,
                                        image_size=output_size,
                                        data_format=data_format)

        # strides are the scaling factors
        strides = get_tensor_shape(batch_size=1,
                                   channel_size=1,
                                   image_size=factors,
                                   data_format=data_format)

        # actual tensorflow operations - channelwise!
        split_inputs = tf.split(inputs,
                                num_inputs,
                                axis=channel_axis,
                                name='split')
        output_list = []
        for i in range(len(split_inputs)):
            if dim == 2:
                current_output = tf.nn.conv2d_transpose(
                    split_inputs[i],
                    kernel,
                    output_shape,
                    strides,
                    data_format=data_format_tf,
                    name='conv' + str(i),
                    padding=padding)
            else:  # dim == 3
                current_output = tf.nn.conv3d_transpose(
                    split_inputs[i],
                    kernel,
                    output_shape,
                    strides,
                    data_format=data_format_tf,
                    name='conv' + str(i),
                    padding=padding)
            output_list.append(current_output)
        outputs = tf.concat(output_list, axis=channel_axis, name='concat')

        # make a final cropping, if specified
        if cropping:
            image_paddings = [
                int((kernel.shape[i] - factors[i]) / 2) for i in range(dim)
            ]
            paddings = get_tensor_shape(batch_size=0,
                                        channel_size=0,
                                        image_size=image_paddings,
                                        data_format=data_format)
            output_size_cropped = [
                input_size[i] * factors[i] for i in range(dim)
            ]
            # build the slice size in the same layout as the input tensor
            sizes = get_tensor_shape(batch_size=inputs_shape[0],
                                     channel_size=num_inputs,
                                     image_size=output_size_cropped,
                                     data_format=data_format)
            outputs = tf.slice(outputs, paddings, sizes)

        return tf.identity(outputs, name='output')
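A worked example of the output-size arithmetic (the concrete numbers are for illustration only): with a 1D input size of 10, an upsampling factor of 2, and a kernel of width 4, 'SAME' yields 20, 'VALID' yields 10 * 2 + 4 - 2 = 22, and 'VALID_CROPPED' then crops (4 - 2) / 2 = 1 pixel per border to get back to 20.

input_size, factor, kernel = 10, 2, 4
print(input_size * factor)                    # 20, 'SAME' output size
print(input_size * factor + kernel - factor)  # 22, 'VALID' output size
print((kernel - factor) // 2)                 # 1 pixel cropped per border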
Example #12
def softmax(labels, logits, weights=None, data_format='channels_first'):
    channel_index = get_channel_index(labels, data_format)
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels,
                                                      logits=logits,
                                                      dim=channel_index)
    return tf.reduce_mean(loss)
Example #13
def network_scn(input,
                num_landmarks,
                is_training,
                data_format='channels_first'):
    num_filters = 128
    local_kernel_size = [5, 5]
    spatial_kernel_size = [15, 15]
    downsampling_factor = 8
    padding = 'same'
    kernel_initializer = he_initializer
    activation = tf.nn.relu
    heatmap_initializer = tf.truncated_normal_initializer(stddev=0.0001)
    local_activation = None
    spatial_activation = None
    with tf.variable_scope('local_appearance'):
        node = conv2d(input,
                      num_filters,
                      kernel_size=local_kernel_size,
                      name='conv1',
                      activation=activation,
                      kernel_initializer=kernel_initializer,
                      padding=padding,
                      data_format=data_format,
                      is_training=is_training)
        node = conv2d(node,
                      num_filters,
                      kernel_size=local_kernel_size,
                      name='conv2',
                      activation=activation,
                      kernel_initializer=kernel_initializer,
                      padding=padding,
                      data_format=data_format,
                      is_training=is_training)
        node = conv2d(node,
                      num_filters,
                      kernel_size=local_kernel_size,
                      name='conv3',
                      activation=activation,
                      kernel_initializer=kernel_initializer,
                      padding=padding,
                      data_format=data_format,
                      is_training=is_training)
        local_heatmaps = conv2d(node,
                                num_landmarks,
                                kernel_size=local_kernel_size,
                                name='local_heatmaps',
                                activation=local_activation,
                                kernel_initializer=heatmap_initializer,
                                padding=padding,
                                data_format=data_format,
                                is_training=is_training)
    with tf.variable_scope('spatial_configuration'):
        local_heatmaps_downsampled = avg_pool2d(
            local_heatmaps, [downsampling_factor, downsampling_factor],
            name='local_heatmaps_downsampled',
            data_format=data_format)
        channel_axis = get_channel_index(local_heatmaps_downsampled,
                                         data_format)
        local_heatmaps_downsampled_split = tf.split(local_heatmaps_downsampled,
                                                    num_landmarks,
                                                    channel_axis)
        spatial_heatmaps_downsampled_split = []
        for i in range(num_landmarks):
            local_heatmaps_except_i = tf.concat([
                local_heatmaps_downsampled_split[j]
                for j in range(num_landmarks) if i != j
            ],
                                                name='h_app_except_' + str(i),
                                                axis=channel_axis)
            h_acc = conv2d(local_heatmaps_except_i,
                           1,
                           kernel_size=spatial_kernel_size,
                           name='h_acc_' + str(i),
                           activation=spatial_activation,
                           kernel_initializer=heatmap_initializer,
                           padding=padding,
                           data_format=data_format,
                           is_training=is_training)
            spatial_heatmaps_downsampled_split.append(h_acc)
        spatial_heatmaps_downsampled = tf.concat(
            spatial_heatmaps_downsampled_split,
            name='spatial_heatmaps_downsampled',
            axis=channel_axis)
        spatial_heatmaps = upsample2d_linear(
            spatial_heatmaps_downsampled,
            [downsampling_factor, downsampling_factor],
            name='spatial_prediction',
            padding='valid_cropped',
            data_format=data_format)
    with tf.variable_scope('combination'):
        heatmaps = local_heatmaps * spatial_heatmaps
    return heatmaps
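A hypothetical usage sketch (the input shape and landmark count are assumptions): the network takes a single-channel image whose side lengths are divisible by the downsampling factor 8, and returns one heatmap per landmark, combining local appearance and spatial configuration by elementwise multiplication.

import tensorflow as tf

image = tf.placeholder(tf.float32, [1, 1, 256, 256])  # assumed NCHW input
heatmaps = network_scn(image, num_landmarks=5, is_training=False,
                       data_format='channels_first')  # -> [1, 5, 256, 256]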