Пример #1
0
  def __init__(self,
               num_classes: int,
               metric_type: Optional[str] = None,
               per_class_metric: bool = False,
               name: Optional[str] = None,
               dtype: Optional[str] = None):
    """Creates a Dice-score segmentation evaluator.

    Args:
      num_classes: Number of segmentation classes; sets how many per-class
        accumulators are created when `per_class_metric` is enabled.
      metric_type: Optional `str` selecting the Dice score variant for the
        overall score; forwarded to `SegmentationLossDiceScore`.
      per_class_metric: If True, also set up per-class Dice score state.
      name: Optional `str` name of the metric instance.
      dtype: The data type of the metric result.
    """
    dice_score_cls = segmentation_losses.SegmentationLossDiceScore

    def new_accumulator():
      # Fresh scalar float variable initialised to 0.0.
      return tf.Variable(0.0)

    self._num_classes = num_classes
    self._per_class_metric = per_class_metric

    # The overall score honours the caller-selected `metric_type`.
    self._dice_op_overall = dice_score_cls(metric_type=metric_type)
    self._dice_scores_overall = new_accumulator()
    self._count = new_accumulator()

    if per_class_metric:
      # Per-class scores always use the raw Dice score, so `metric_type` is
      # intentionally not forwarded here.
      self._dice_op_per_class = dice_score_cls()
      # Score variables are created before count variables, one per class.
      self._dice_scores_per_class = [
          new_accumulator() for _ in range(num_classes)]
      self._count_per_class = [
          new_accumulator() for _ in range(num_classes)]

    self.name = name
    self.dtype = dtype
 def test_supported_loss(self, metric_type, output, expected_score):
     """Checks the Dice score of constant logits against all-ones labels.

     Args:
       metric_type: Dice score variant passed to `SegmentationLossDiceScore`.
       output: Value passed to `tf.constant` to build the logits tensor.
       expected_score: Expected score, compared to one decimal place.
     """
     volume_shape = [2, 128, 128, 128, 1]
     dice_loss = segmentation_losses.SegmentationLossDiceScore(
         metric_type=metric_type)
     predictions = tf.constant(output, shape=volume_shape, dtype=tf.float32)
     targets = tf.ones(shape=volume_shape, dtype=tf.float32)
     score = dice_loss(logits=predictions, labels=targets)
     self.assertAlmostEqual(
         score.numpy(), expected_score, places=1)
Пример #3
0
  def build_losses(self,
                   labels: tf.Tensor,
                   model_outputs: tf.Tensor,
                   aux_losses=None) -> tf.Tensor:
    """Computes the segmentation loss.

    The primary term is an 'adaptive' Dice score loss. Any auxiliary losses
    are summed and added on top.

    Args:
      labels: Ground-truth labels.
      model_outputs: Output logits of the classifier.
      aux_losses: Optional list of auxiliary loss tensors, e.g. the `losses`
        of a keras.Model. Ignored when None or empty.

    Returns:
      The total loss tensor.
    """
    dice_loss = segmentation_losses.SegmentationLossDiceScore(
        metric_type='adaptive')
    loss = dice_loss(model_outputs, labels)

    # No auxiliary terms: the Dice loss is the whole objective.
    if not aux_losses:
      return loss

    loss += tf.add_n(aux_losses)
    return loss