示例#1
0
    def __init__(
            self,
            n_class,
            loss_type='Dice',
            softmax=True,
            loss_func_params=None,
            name='loss_function'):
        """Configure the segmentation loss.

        :param n_class: number of segmentation classes; must be positive
            (checked with an ``assert``).
        :param loss_type: identifier handed to
            ``LossSegmentationFactory.create`` to look up the loss function.
        :param softmax: coerced to ``bool`` and stored as ``self._softmax``.
            Forced to ``False`` when the selected loss function's
            ``__name__`` starts with ``'cross_entropy'`` or contains
            ``'xent'``.
            NOTE(review): presumably controls whether a softmax is applied
            before the loss is evaluated -- confirm against the code that
            reads ``self._softmax``.
        :param loss_func_params: function-specific additional parameters;
            an empty dict is stored when this is ``None``.
        :param name: forwarded to the parent class constructor.
        """
        super(LossFunction, self).__init__(name=name)
        assert n_class > 0, (
            "Number of classes for segmentation loss should be positive.")
        self._num_classes = n_class
        self._softmax = bool(softmax)

        # Look up the loss function, then keep its function-specific params.
        self._data_loss_func = LossSegmentationFactory.create(loss_type)
        if loss_func_params is None:
            loss_func_params = dict()
        self._loss_func_params = loss_func_params

        # Cross-entropy style losses are recognised by function name only.
        func_name = self._data_loss_func.__name__
        is_cross_entropy = (
            func_name.startswith('cross_entropy') or 'xent' in func_name)
        if is_cross_entropy:
            # As the log message states, these losses already apply a
            # softmax internally, so the requested flag is overridden.
            tf.logging.info(
                'Cross entropy loss function calls '
                'tf.nn.sparse_softmax_cross_entropy_with_logits '
                'which always performs a softmax internally.')
            self._softmax = False
示例#2
0
    def __init__(
            self,
            n_class,
            loss_type='Dice',
            softmax=True,
            loss_func_params=None,
            name='loss_function'):
        """Configure the segmentation loss.

        :param n_class: number of segmentation classes; must be positive
            (checked with an ``assert``).
        :param loss_type: identifier handed to
            ``LossSegmentationFactory.create`` to look up the loss function.
        :param softmax: coerced to ``bool`` and stored as ``self._softmax``.
            Forced to ``False`` when the selected loss function's
            ``__name__`` is exactly ``'cross_entropy'``.
            NOTE(review): presumably controls whether a softmax is applied
            before the loss is evaluated -- confirm against the code that
            reads ``self._softmax``.
        :param loss_func_params: function-specific additional parameters;
            an empty dict is stored when this is ``None``.
        :param name: forwarded to the parent class constructor.
        """
        super(LossFunction, self).__init__(name=name)
        assert n_class > 0, (
            "Number of classes for segmentation loss should be positive.")
        self._num_classes = n_class
        self._softmax = bool(softmax)

        # Look up the loss function, then keep its function-specific params.
        self._data_loss_func = LossSegmentationFactory.create(loss_type)
        if loss_func_params is None:
            loss_func_params = dict()
        self._loss_func_params = loss_func_params

        # Only the function named exactly 'cross_entropy' is special-cased.
        is_cross_entropy = self._data_loss_func.__name__ == 'cross_entropy'
        if is_cross_entropy:
            # As the log message states, this loss already applies a
            # softmax internally, so the requested flag is overridden.
            tf.logging.info(
                'Cross entropy loss function calls '
                'tf.nn.sparse_softmax_cross_entropy_with_logits '
                'which always performs a softmax internally.')
            self._softmax = False
示例#3
0
 def make_callable_loss_func(self, type_str):
     """Select the data loss function for this instance.

     Passes ``type_str`` to ``LossSegmentationFactory.create`` and stores
     the returned object as ``self._data_loss_func``.

     :param type_str: loss identifier understood by the factory.
     """
     selected_loss = LossSegmentationFactory.create(type_str)
     self._data_loss_func = selected_loss