Example #1
    def loss(self, logits, labels, images):
        '''
        Loss to be minimised by the neural network
        :param logits: The output of the neural network before the softmax
        :param labels: The ground truth labels in standard (i.e. not one-hot) format
        :param images: The input images (not used in the loss computation itself)
        :return: The total loss including weight decay, the segmentation loss without weight decay, and the weight decay term alone
        '''
        nlabels = self.config.nlabels
        loss_type = self.config.loss_type
        weight_decay = self.config.weight_decay
        # Default weighting for compound losses when the config does not define loss_hyper_params
        loss_hyper_params = getattr(self.config, 'loss_hyper_params', [1, 0.2])

        if nlabels > 2:
            oh_labels = tf.one_hot(labels, depth=nlabels)
        else:
            oh_labels = tf.cast(labels, tf.float32)

        with tf.variable_scope('weights_norm'):
            weights_norm = tf.reduce_sum(
                input_tensor=weight_decay * tf.stack(
                    [tf.nn.l2_loss(ii) for ii in tf.get_collection('weight_variables')]),
                name='weights_norm')

        if loss_type == 'weighted_crossentropy':
            segmentation_loss = losses.pixel_wise_cross_entropy_loss_weighted(
                logits, oh_labels, class_weights=[0.1, 0.3, 0.3, 0.3])
        elif loss_type == 'crossentropy':
            segmentation_loss = losses.pixel_wise_cross_entropy_loss(
                logits, oh_labels)
        elif loss_type == 'dice':
            segmentation_loss = losses.dice_loss(logits,
                                                 oh_labels,
                                                 only_foreground=False)
        elif loss_type == 'dice_onlyfg':
            segmentation_loss = losses.dice_loss(logits,
                                                 oh_labels,
                                                 only_foreground=True)
        elif loss_type == 'crossentropy_and_dice':
            segmentation_loss = loss_hyper_params[0] * losses.pixel_wise_cross_entropy_loss(logits, oh_labels) \
                + loss_hyper_params[1] * losses.dice_loss(logits, oh_labels)
        elif loss_type == 'dice_cc':
            segmentation_loss = loss_hyper_params[0] * losses.dice_loss(logits, oh_labels) \
                + loss_hyper_params[1] * losses.connected_component_loss(logits, oh_labels)
        else:
            raise ValueError('Unknown loss: %s' % loss_type)

        total_loss = tf.add(segmentation_loss, weights_norm)

        return total_loss, segmentation_loss, weights_norm
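
The method returns three tensors so the regularised objective can be minimised while its two components are monitored separately. Below is a minimal sketch of how those return values might be wired into a TF 1.x training setup; `model`, `logits`, `labels_pl`, and `images_pl` are hypothetical names, not part of the example above.

# Sketch only (TF 1.x): minimise the regularised loss, log the components separately.
# `model`, `logits`, `labels_pl`, and `images_pl` are hypothetical placeholders/objects.
import tensorflow as tf

total_loss, seg_loss, weights_norm = model.loss(logits, labels_pl, images_pl)

train_op = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(total_loss)
tf.summary.scalar('segmentation_loss', seg_loss)
tf.summary.scalar('weights_norm', weights_norm)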
Example #2
def loss(logits, labels, nlabels, loss_type, weight_decay=0.0):
    '''
    Loss to be minimised by the neural network
    :param logits: The output of the neural network before the softmax
    :param labels: The ground truth labels in standard (i.e. not one-hot) format
    :param nlabels: The number of GT labels
    :param loss_type: Can be 'weighted_crossentropy'/'crossentropy'/'dice'/'dice_onlyfg'/'crossentropy_and_dice'
    :param weight_decay: The weight for the L2 regularisation of the network parameters
    :return: The total loss including weight decay, the segmentation loss without weight decay, and the weight decay term alone
    '''

    labels = tf.one_hot(labels, depth=nlabels)

    with tf.variable_scope('weights_norm'):
        weights_norm = tf.reduce_sum(
            input_tensor=weight_decay * tf.stack(
                [tf.nn.l2_loss(ii) for ii in tf.get_collection('weight_variables')]),
            name='weights_norm')

    if loss_type == 'weighted_crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss_weighted(
            logits, labels, class_weights=[0.1, 0.3, 0.3, 0.3])
    elif loss_type == 'crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss(
            logits, labels)
    elif loss_type == 'dice':
        segmentation_loss = losses.dice_loss(logits,
                                             labels,
                                             only_foreground=False)
    elif loss_type == 'dice_onlyfg':
        segmentation_loss = losses.dice_loss(logits,
                                             labels,
                                             only_foreground=True)
    elif loss_type == 'crossentropy_and_dice':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss(
            logits, labels) + 0.2 * losses.dice_loss(logits, labels)
    else:
        raise ValueError('Unknown loss: %s' % loss_type)

    total_loss = tf.add(segmentation_loss, weights_norm)

    return total_loss, segmentation_loss, weights_norm
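
Note that the weight-decay term only covers variables that have been registered in the custom 'weight_variables' collection. A minimal sketch of that registration step is shown below, assuming TF 1.x; the scope name and kernel shape are illustrative, not the repository's actual layer code.

# Sketch only (TF 1.x): variables must be added to the 'weight_variables'
# collection for the L2 term in loss() to pick them up.
import tensorflow as tf

with tf.variable_scope('conv1'):
    kernel = tf.get_variable('W', shape=[3, 3, 32, 32],
                             initializer=tf.truncated_normal_initializer(stddev=0.01))
    tf.add_to_collection('weight_variables', kernel)

# loss() then adds weight_decay * tf.nn.l2_loss(kernel) to the total loss.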
Example #3
def loss(logits, labels, nlabels, loss_type):
    '''
    Loss to be minimised by the neural network
    :param logits: The output of the neural network before the softmax
    :param labels: The ground truth labels in standard (i.e. not one-hot) format
    :param nlabels: The number of GT labels
    :param loss_type: Can be 'crossentropy'/'dice'/'crossentropy_and_dice'
    :return: The segmentation loss
    '''

    labels = tf.one_hot(labels, depth=nlabels)

    if loss_type == 'crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss(logits, labels)
    elif loss_type == 'dice':
        segmentation_loss = losses.dice_loss(logits, labels)
    elif loss_type == 'crossentropy_and_dice':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss(logits, labels) + 0.2*losses.dice_loss(logits, labels)
    else:
        raise ValueError('Unknown loss: %s' % loss_type)

    return segmentation_loss
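
The `losses` module itself is not shown in these examples. For reference, a standard soft-Dice formulation is sketched below under the assumption that `losses.dice_loss` follows the usual definition; the actual implementation in the repository may differ (for instance in how the background class or the batch dimension is handled).

# Illustrative standard soft-Dice loss, not the repository's actual losses.dice_loss.
import tensorflow as tf

def soft_dice_loss(logits, one_hot_labels, epsilon=1e-10):
    probs = tf.nn.softmax(logits)
    spatial_axes = [1, 2]  # spatial dimensions of a [batch, x, y, nlabels] tensor
    intersection = tf.reduce_sum(probs * one_hot_labels, axis=spatial_axes)
    denominator = (tf.reduce_sum(probs, axis=spatial_axes)
                   + tf.reduce_sum(one_hot_labels, axis=spatial_axes))
    dice_per_class = (2.0 * intersection + epsilon) / (denominator + epsilon)
    return 1.0 - tf.reduce_mean(dice_per_class)  # average over batch and classes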
Example #4
def loss(logits,
         labels,
         nlabels,
         loss_type,
         mask_for_loss_within_mask=None,
         are_labels_1hot=False):
    '''
    Loss to be minimised by the neural network
    :param logits: The output of the neural network before the softmax
    :param labels: The ground truth labels in standard (i.e. not one-hot) format
    :param nlabels: The number of GT labels
    :param loss_type: Can be 'crossentropy'/'crossentropy_reverse'/'dice'/'dice_within_mask'
    :param mask_for_loss_within_mask: Binary mask restricting where the loss is computed (required for 'dice_within_mask')
    :param are_labels_1hot: Set to True if the labels are already one-hot encoded
    :return: The segmentation loss
    '''

    if are_labels_1hot is False:
        labels = tf.one_hot(labels, depth=nlabels)

    if loss_type == 'crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss(
            logits, labels)

    elif loss_type == 'crossentropy_reverse':
        predicted_probabilities = tf.nn.softmax(logits)
        segmentation_loss = losses.pixel_wise_cross_entropy_loss_using_probs(
            predicted_probabilities, labels)

    elif loss_type == 'dice':
        segmentation_loss = losses.dice_loss(logits, labels)

    elif loss_type == 'dice_within_mask':
        if mask_for_loss_within_mask is None:
            raise ValueError("loss_type 'dice_within_mask' requires mask_for_loss_within_mask")
        segmentation_loss = losses.dice_loss_within_mask(
            logits, labels, mask_for_loss_within_mask)
    else:
        raise ValueError('Unknown loss: %s' % loss_type)

    return segmentation_loss
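
This variant accepts labels that are already one-hot encoded and an optional binary mask restricting where the Dice loss is computed. A minimal usage sketch, where `logits`, `oh_labels`, and `roi_mask` are hypothetical tensors not defined in the example above:

# Sketch only: 'dice_within_mask' with pre-one-hot labels and a hypothetical ROI mask.
masked_dice = loss(logits,
                   oh_labels,
                   nlabels=4,
                   loss_type='dice_within_mask',
                   mask_for_loss_within_mask=roi_mask,
                   are_labels_1hot=True)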