def __init__(self, n_class, loss_type='Dice', softmax=True,
             loss_func_params=None, name='loss_function'):
    """Build a segmentation loss layer.

    :param n_class: number of segmentation classes; must be positive.
    :param loss_type: key passed to ``LossSegmentationFactory.create`` to
        select the loss function (default ``'Dice'``).
    :param softmax: whether to apply a softmax to the predictions before
        computing the loss; forced to ``False`` for cross-entropy-style
        losses (see below).
    :param loss_func_params: optional dict of extra keyword arguments for
        the selected loss function.
    :param name: layer name.
    :raises ValueError: if ``n_class`` is not positive.
    """
    super(LossFunction, self).__init__(name=name)
    # Validate with an explicit exception: ``assert`` is stripped when
    # Python runs with -O, which would silently accept a bad n_class.
    if n_class <= 0:
        raise ValueError(
            "Number of classes for segmentation loss should be positive.")
    self._num_classes = n_class
    self._softmax = bool(softmax)

    # set loss function and function-specific additional params.
    self._data_loss_func = LossSegmentationFactory.create(loss_type)
    self._loss_func_params = \
        loss_func_params if loss_func_params is not None else dict()

    # Cross-entropy-style losses ('cross_entropy*' or anything containing
    # 'xent') call tf.nn.sparse_softmax_cross_entropy_with_logits, which
    # applies its own softmax internally — so disable the external one to
    # avoid applying softmax twice.
    data_loss_function_name = self._data_loss_func.__name__
    if data_loss_function_name.startswith('cross_entropy') \
            or 'xent' in data_loss_function_name:
        tf.logging.info(
            'Cross entropy loss function calls '
            'tf.nn.sparse_softmax_cross_entropy_with_logits '
            'which always performs a softmax internally.')
        self._softmax = False
def __init__(self, n_class, loss_type='Dice', softmax=True,
             loss_func_params=None, name='loss_function'):
    """Build a segmentation loss layer.

    :param n_class: number of segmentation classes; must be positive.
    :param loss_type: key passed to ``LossSegmentationFactory.create`` to
        select the loss function (default ``'Dice'``).
    :param softmax: whether to apply a softmax to the predictions before
        computing the loss; forced to ``False`` for the cross-entropy loss
        (see below).
    :param loss_func_params: optional dict of extra keyword arguments for
        the selected loss function.
    :param name: layer name.
    :raises ValueError: if ``n_class`` is not positive.
    """
    super(LossFunction, self).__init__(name=name)
    # Validate with an explicit exception: ``assert`` is stripped when
    # Python runs with -O, which would silently accept a bad n_class.
    if n_class <= 0:
        raise ValueError(
            "Number of classes for segmentation loss should be positive.")
    self._num_classes = n_class
    self._softmax = bool(softmax)

    # set loss function and function-specific additional params.
    self._data_loss_func = LossSegmentationFactory.create(loss_type)
    self._loss_func_params = \
        loss_func_params if loss_func_params is not None else dict()

    # The cross-entropy loss calls
    # tf.nn.sparse_softmax_cross_entropy_with_logits, which applies its own
    # softmax internally — disable the external one to avoid applying
    # softmax twice.
    if self._data_loss_func.__name__ == 'cross_entropy':
        tf.logging.info(
            'Cross entropy loss function calls '
            'tf.nn.sparse_softmax_cross_entropy_with_logits '
            'which always performs a softmax internally.')
        self._softmax = False
def make_callable_loss_func(self, type_str):
    """Replace the current loss function with the one named by ``type_str``.

    ``type_str`` is resolved through ``LossSegmentationFactory.create``;
    the result is stored on ``self._data_loss_func`` for later use.
    NOTE(review): this does not re-run the cross-entropy/softmax check done
    in ``__init__``, so ``self._softmax`` keeps its previous value — confirm
    that is intended when switching loss types.
    """
    self._data_loss_func = LossSegmentationFactory.create(type_str)