Example #1
0
    def loss(self, logits, labels, images):
        '''
        Loss to be minimised by the neural network.

        :param logits: The output of the neural network before the softmax
        :param labels: The ground truth labels in standard (i.e. not one-hot) format
        :param images: The input (not used by this method; kept for interface compatibility)
        :return: A tuple (total_loss, segmentation_loss, weights_norm): the total loss
                 including weight decay, the loss without weight decay, and only the
                 weight decay term
        :raises ValueError: if self.config.loss_type is not one of the supported losses
        '''
        nlabels = self.config.nlabels
        loss_type = self.config.loss_type
        weight_decay = self.config.weight_decay
        # Weights of the two terms in the combined losses. Look up exactly the
        # attribute that is used, so a configured 'loss_hyper_params' is honoured;
        # fall back to (1, 0.2) when the config does not define it.
        loss_hyper_params = getattr(self.config, 'loss_hyper_params', (1, 0.2))

        if nlabels > 2:
            oh_labels = tf.one_hot(labels, depth=nlabels)
        else:
            # Binary case: the labels are used directly as floats, not one-hot encoded.
            oh_labels = tf.cast(labels, tf.float32)

        with tf.variable_scope('weights_norm'):
            # tf.nn.l2_loss of every variable in the 'weight_variables' collection,
            # scaled by weight_decay and summed into a single scalar.
            weights_norm = tf.reduce_sum(
                input_tensor=weight_decay * tf.stack([
                    tf.nn.l2_loss(ii)
                    for ii in tf.get_collection('weight_variables')
                ]),
                name='weights_norm')

        if loss_type == 'weighted_crossentropy':
            # NOTE(review): four fixed class weights -- assumes nlabels == 4; confirm.
            segmentation_loss = losses.pixel_wise_cross_entropy_loss_weighted(
                logits, oh_labels, class_weights=[0.1, 0.3, 0.3, 0.3])
        elif loss_type == 'crossentropy':
            segmentation_loss = losses.pixel_wise_cross_entropy_loss(
                logits, oh_labels)
        elif loss_type == 'dice':
            segmentation_loss = losses.dice_loss(logits,
                                                 oh_labels,
                                                 only_foreground=False)
        elif loss_type == 'dice_onlyfg':
            segmentation_loss = losses.dice_loss(logits,
                                                 oh_labels,
                                                 only_foreground=True)
        elif loss_type == 'crossentropy_and_dice':
            segmentation_loss = loss_hyper_params[0] * losses.pixel_wise_cross_entropy_loss(logits, oh_labels) \
                + loss_hyper_params[1] * losses.dice_loss(logits, oh_labels)
        elif loss_type == 'dice_cc':
            segmentation_loss = loss_hyper_params[0] * losses.dice_loss(logits, oh_labels) \
                + loss_hyper_params[1] * losses.connected_component_loss(logits, oh_labels)
        else:
            raise ValueError('Unknown loss: %s' % loss_type)

        total_loss = tf.add(segmentation_loss, weights_norm)

        return total_loss, segmentation_loss, weights_norm
Example #2
0
def loss(logits, labels, nlabels, loss_type, weight_decay=0.0):
    '''
    Build the objective that the network is trained to minimise.

    :param logits: Network output before the softmax
    :param labels: Ground-truth labels as class indices (i.e. not one-hot)
    :param nlabels: Number of ground-truth classes (depth of the one-hot encoding)
    :param loss_type: One of 'weighted_crossentropy', 'crossentropy', 'dice',
                      'dice_onlyfg' or 'crossentropy_and_dice'
    :param weight_decay: Weight of the L2 regularisation on the network parameters
    :return: (total_loss, segmentation_loss, weights_norm) -- the loss including
             weight decay, the loss without weight decay, and the weight-decay term alone
    :raises ValueError: if loss_type is not recognised
    '''
    one_hot = tf.one_hot(labels, depth=nlabels)

    with tf.compat.v1.variable_scope('weights_norm'):
        # tf.nn.l2_loss of every variable in the 'weight_variables' collection.
        l2_terms = [
            tf.nn.l2_loss(var)
            for var in tf.compat.v1.get_collection('weight_variables')
        ]
        weights_norm = tf.reduce_sum(input_tensor=weight_decay * tf.stack(l2_terms),
                                     name='weights_norm')

    def weighted_ce():
        # NOTE(review): four fixed class weights -- assumes nlabels == 4; confirm.
        return losses.pixel_wise_cross_entropy_loss_weighted(
            logits, one_hot, class_weights=[0.076, 0.308, 0.308, 0.308])

    # Each entry builds its loss lazily, so only the selected one creates graph ops.
    # `config` is resolved from module scope when 'crossentropy_and_dice' is built.
    builders = {
        'weighted_crossentropy': weighted_ce,
        'crossentropy':
            lambda: losses.pixel_wise_cross_entropy_loss(logits, one_hot),
        'dice':
            lambda: losses.dice_loss(logits, one_hot, only_foreground=False),
        'dice_onlyfg':
            lambda: losses.dice_loss(logits, one_hot, only_foreground=True),
        'crossentropy_and_dice':
            lambda: config.alfa * weighted_ce() + config.beta * losses.dice_loss(
                logits, one_hot, only_foreground=True),
    }

    if loss_type not in builders:
        raise ValueError('Unknown loss: %s' % loss_type)
    segmentation_loss = builders[loss_type]()

    total_loss = tf.add(segmentation_loss, weights_norm)
    return total_loss, segmentation_loss, weights_norm
    def loss(self):
        '''
        Build the task loss selected by self.exp_config.loss_type.

        self.y_pl is one-hot encoded to self.nlabels classes and, together with
        self.l_pl_ (passed to the loss helpers as the logits), fed to the chosen loss.

        :return: The task loss
        :raises ValueError: if exp_config.loss_type is not one of 'crossentropy',
                            'dice_micro', 'dice_macro' or 'dice_macro_robust'
        '''
        one_hot_targets = tf.one_hot(self.y_pl, depth=self.nlabels)
        requested = self.exp_config.loss_type

        if requested == 'crossentropy':
            return losses.cross_entropy_loss(logits=self.l_pl_,
                                             labels=one_hot_targets)

        # The Dice variants differ only in the `mode` argument handed to dice_loss.
        dice_mode_by_loss = {
            'dice_micro': 'micro',
            'dice_macro': 'macro',
            'dice_macro_robust': 'macro_robust',
        }
        if requested not in dice_mode_by_loss:
            raise ValueError("Unknown loss_type in exp_config: '%s'" % requested)

        return losses.dice_loss(logits=self.l_pl_,
                                labels=one_hot_targets,
                                mode=dice_mode_by_loss[requested])
Example #4
0
def loss(logits, labels, nlabels, loss_type, dice_weight=0.2):
    '''
    Loss to be minimised by the neural network
    :param logits: The output of the neural network before the softmax
    :param labels: The ground truth labels in standard (i.e. not one-hot) format
    :param nlabels: The number of GT labels (depth of the one-hot encoding)
    :param loss_type: Can be 'crossentropy'/'dice'/'crossentropy_and_dice'
    :param dice_weight: Weight of the Dice term when loss_type is
                        'crossentropy_and_dice' (default 0.2); ignored otherwise
    :return: The segmentation loss
    :raises ValueError: if loss_type is not one of the supported values
    '''

    labels = tf.one_hot(labels, depth=nlabels)

    if loss_type == 'crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss(logits, labels)
    elif loss_type == 'dice':
        segmentation_loss = losses.dice_loss(logits, labels)
    elif loss_type == 'crossentropy_and_dice':
        # Cross-entropy plus a down-weighted Dice term.
        segmentation_loss = (losses.pixel_wise_cross_entropy_loss(logits, labels)
                             + dice_weight * losses.dice_loss(logits, labels))
    else:
        raise ValueError('Unknown loss: %s' % loss_type)

    return segmentation_loss
Example #5
0
def loss(logits,
         labels,
         nlabels,
         loss_type,
         mask_for_loss_within_mask=None,
         are_labels_1hot=False):
    '''
    Loss to be minimised by the neural network
    :param logits: The output of the neural network before the softmax
    :param labels: The ground truth labels; class indices (not one-hot) unless
                   are_labels_1hot is True
    :param nlabels: The number of GT labels (depth of the one-hot encoding)
    :param loss_type: Can be 'crossentropy'/'crossentropy_reverse'/'dice'/'dice_within_mask'
    :param mask_for_loss_within_mask: Mask passed to losses.dice_loss_within_mask;
                                      required when loss_type is 'dice_within_mask'
    :param are_labels_1hot: If True, labels are already one-hot and are used as-is
    :return: The segmentation loss
    :raises ValueError: if loss_type is unknown, or if loss_type is 'dice_within_mask'
                        and mask_for_loss_within_mask is None
    '''

    if not are_labels_1hot:
        labels = tf.one_hot(labels, depth=nlabels)

    if loss_type == 'crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss(
            logits, labels)

    elif loss_type == 'crossentropy_reverse':
        # Cross-entropy evaluated on softmax probabilities rather than raw logits.
        predicted_probabilities = tf.nn.softmax(logits)
        segmentation_loss = losses.pixel_wise_cross_entropy_loss_using_probs(
            predicted_probabilities, labels)

    elif loss_type == 'dice':
        segmentation_loss = losses.dice_loss(logits, labels)

    elif loss_type == 'dice_within_mask':
        if mask_for_loss_within_mask is None:
            # Fail loudly instead of falling through with segmentation_loss unbound.
            raise ValueError(
                "loss_type 'dice_within_mask' requires mask_for_loss_within_mask")
        segmentation_loss = losses.dice_loss_within_mask(
            logits, labels, mask_for_loss_within_mask)

    else:
        raise ValueError('Unknown loss: %s' % loss_type)

    return segmentation_loss