def loss(self, logits, labels, images):
    '''
    Loss to be minimised by the neural network

    The settings nlabels, loss_type, weight_decay and (optionally) loss_hyper_params are read from self.config.

    :param logits: The output of the neural network before the softmax
    :param labels: The ground truth labels in standard (i.e. not one-hot) format
    :param images: The input (not used in the loss computation)
    :return: The total loss including weight decay, the loss without weight decay, only the weight decay
    :raises ValueError: If config.loss_type is not one of 'weighted_crossentropy', 'crossentropy', 'dice',
                        'dice_onlyfg', 'crossentropy_and_dice', 'dice_cc'
    '''

    nlabels = self.config.nlabels
    loss_type = self.config.loss_type
    weight_decay = self.config.weight_decay

    # Look up the same attribute name that is actually used. The previous
    # hasattr() check tested 'loss_hyper_param' (singular), so a configured
    # 'loss_hyper_params' was silently replaced by the default weights.
    loss_hyper_params = getattr(self.config, 'loss_hyper_params', [1, 0.2])

    if nlabels > 2:
        oh_labels = tf.one_hot(labels, depth=nlabels)
    else:
        # With at most two labels the labels are not one-hot encoded, only cast to float
        oh_labels = tf.cast(labels, tf.float32)

    # L2 penalty: weight_decay times the l2_loss of every variable in the 'weight_variables' collection
    with tf.variable_scope('weights_norm'):
        weights_norm = tf.reduce_sum(
            input_tensor=weight_decay * tf.stack(
                [tf.nn.l2_loss(ii) for ii in tf.get_collection('weight_variables')]
            ),
            name='weights_norm'
        )

    if loss_type == 'weighted_crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss_weighted(
            logits, oh_labels, class_weights=[0.1, 0.3, 0.3, 0.3])
    elif loss_type == 'crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss(logits, oh_labels)
    elif loss_type == 'dice':
        segmentation_loss = losses.dice_loss(logits, oh_labels, only_foreground=False)
    elif loss_type == 'dice_onlyfg':
        segmentation_loss = losses.dice_loss(logits, oh_labels, only_foreground=True)
    elif loss_type == 'crossentropy_and_dice':
        # loss_hyper_params[0] weights the cross entropy term, loss_hyper_params[1] the dice term
        segmentation_loss = loss_hyper_params[0] * losses.pixel_wise_cross_entropy_loss(logits, oh_labels) \
            + loss_hyper_params[1] * losses.dice_loss(logits, oh_labels)
    elif loss_type == 'dice_cc':
        # loss_hyper_params[0] weights the dice term, loss_hyper_params[1] the connected component term
        segmentation_loss = loss_hyper_params[0] * losses.dice_loss(logits, oh_labels) \
            + loss_hyper_params[1] * losses.connected_component_loss(logits, oh_labels)
    else:
        raise ValueError('Unknown loss: %s' % loss_type)

    total_loss = tf.add(segmentation_loss, weights_norm)

    return total_loss, segmentation_loss, weights_norm
def loss(logits, labels, nlabels, loss_type, weight_decay=0.0):
    '''
    Training objective: a segmentation loss plus an L2 penalty on the network weights

    :param logits: The output of the neural network before the softmax
    :param labels: The ground truth labels in standard (i.e. not one-hot) format
    :param nlabels: The number of GT labels
    :param loss_type: Can be 'weighted_crossentropy'/'crossentropy'/'dice'/'dice_onlyfg'/'crossentropy_and_dice'
    :param weight_decay: The weight for the L2 regularisation of the network parameters
    :return: The total loss including weight decay, the loss without weight decay, only the weight decay
    :raises ValueError: If loss_type is none of the names listed above

    NOTE: 'crossentropy_and_dice' reads its two mixing factors from config.alfa and config.beta.
    '''

    onehot_labels = tf.one_hot(labels, depth=nlabels)

    # Weight decay term over every variable registered in the 'weight_variables' collection
    with tf.compat.v1.variable_scope('weights_norm'):
        l2_terms = [
            tf.nn.l2_loss(variable)
            for variable in tf.compat.v1.get_collection('weight_variables')
        ]
        weights_norm = tf.reduce_sum(
            input_tensor=weight_decay * tf.stack(l2_terms),
            name='weights_norm'
        )

    # Fixed per-class weights used by both weighted cross entropy variants
    fixed_class_weights = [0.076, 0.308, 0.308, 0.308]

    if loss_type == 'weighted_crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss_weighted(
            logits, onehot_labels, class_weights=fixed_class_weights)
    elif loss_type == 'crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss(logits, onehot_labels)
    elif loss_type in ('dice', 'dice_onlyfg'):
        # The two dice variants differ only in whether the background class is included
        segmentation_loss = losses.dice_loss(
            logits, onehot_labels, only_foreground=(loss_type == 'dice_onlyfg'))
    elif loss_type == 'crossentropy_and_dice':
        weighted_ce_term = config.alfa * losses.pixel_wise_cross_entropy_loss_weighted(
            logits, onehot_labels, class_weights=fixed_class_weights)
        foreground_dice_term = config.beta * losses.dice_loss(
            logits, onehot_labels, only_foreground=True)
        segmentation_loss = weighted_ce_term + foreground_dice_term
    else:
        raise ValueError('Unknown loss: %s' % loss_type)

    total_loss = tf.add(segmentation_loss, weights_norm)
    return total_loss, segmentation_loss, weights_norm
def loss(self):
    '''
    Task loss selected by exp_config.loss_type

    The placeholder labels self.y_pl are one-hot encoded with depth self.nlabels and compared
    against the logits self.l_pl_.

    :return: The task loss for 'crossentropy', 'dice_micro', 'dice_macro' or 'dice_macro_robust'
    :raises ValueError: If exp_config.loss_type is none of the names above
    '''

    onehot_targets = tf.one_hot(self.y_pl, depth=self.nlabels)
    selected = self.exp_config.loss_type

    if selected == 'crossentropy':
        return losses.cross_entropy_loss(logits=self.l_pl_, labels=onehot_targets)

    # Each dice variant maps onto the `mode` argument of losses.dice_loss
    dice_variants = (
        ('dice_micro', 'micro'),
        ('dice_macro', 'macro'),
        ('dice_macro_robust', 'macro_robust'),
    )
    for variant_name, dice_mode in dice_variants:
        if selected == variant_name:
            return losses.dice_loss(logits=self.l_pl_, labels=onehot_targets, mode=dice_mode)

    raise ValueError("Unknown loss_type in exp_config: '%s'" % self.exp_config.loss_type)
def loss(logits, labels, nlabels, loss_type):
    '''
    Segmentation loss to be minimised by the neural network

    :param logits: The output of the neural network before the softmax
    :param labels: The ground truth labels in standard (i.e. not one-hot) format
    :param nlabels: The number of GT labels
    :param loss_type: Can be 'crossentropy'/'dice'/'crossentropy_and_dice'
    :return: The segmentation loss
    :raises ValueError: If loss_type is none of the names listed above
    '''

    onehot_labels = tf.one_hot(labels, depth=nlabels)

    if loss_type == 'crossentropy':
        return losses.pixel_wise_cross_entropy_loss(logits, onehot_labels)

    if loss_type == 'dice':
        return losses.dice_loss(logits, onehot_labels)

    if loss_type == 'crossentropy_and_dice':
        # Cross entropy plus the dice loss scaled by a fixed factor of 0.2
        cross_entropy_term = losses.pixel_wise_cross_entropy_loss(logits, onehot_labels)
        dice_term = losses.dice_loss(logits, onehot_labels)
        return cross_entropy_term + 0.2 * dice_term

    raise ValueError('Unknown loss: %s' % loss_type)
def loss(logits, labels, nlabels, loss_type, mask_for_loss_within_mask=None, are_labels_1hot=False):
    '''
    Loss to be minimised by the neural network
    :param logits: The output of the neural network before the softmax
    :param labels: The ground truth labels; in standard (i.e. not one-hot) format unless are_labels_1hot is set
    :param nlabels: The number of GT labels
    :param loss_type: Can be 'crossentropy'/'crossentropy_reverse'/'dice'/'dice_within_mask'
    :param mask_for_loss_within_mask: Mask handed to losses.dice_loss_within_mask;
                                      required when loss_type is 'dice_within_mask'
    :param are_labels_1hot: If true, labels are used as they are instead of being one-hot encoded here
    :return: The segmentation loss
    :raises ValueError: If loss_type is unknown, or if 'dice_within_mask' is requested without a mask
    '''

    # Truthiness test instead of `is False`, so that falsy values such as 0 or a
    # numpy bool still trigger the one-hot encoding.
    if not are_labels_1hot:
        labels = tf.one_hot(labels, depth=nlabels)

    if loss_type == 'crossentropy':
        segmentation_loss = losses.pixel_wise_cross_entropy_loss(logits, labels)

    elif loss_type == 'crossentropy_reverse':
        # This variant works on softmax probabilities rather than on raw logits
        predicted_probabilities = tf.nn.softmax(logits)
        segmentation_loss = losses.pixel_wise_cross_entropy_loss_using_probs(predicted_probabilities, labels)

    elif loss_type == 'dice':
        segmentation_loss = losses.dice_loss(logits, labels)

    elif loss_type == 'dice_within_mask':
        # Fail with a clear message; previously a missing mask left
        # segmentation_loss unassigned and the return statement crashed.
        if mask_for_loss_within_mask is None:
            raise ValueError("Loss 'dice_within_mask' requires mask_for_loss_within_mask to be provided")
        segmentation_loss = losses.dice_loss_within_mask(logits, labels, mask_for_loss_within_mask)

    else:
        raise ValueError('Unknown loss: %s' % loss_type)

    return segmentation_loss