Пример #1
0
 def test_num_classes_not_equal(self):
     """Checks that `update_state` raises on a class-count mismatch.

     The metric is constructed with `num_classes=4` while both the
     prediction and label tensors have a last dimension of 2; the call to
     `update_state` is expected to raise a `ValueError`.
     """
     tensor_shape = [2, 128, 128, 128, 2]
     expected_error = (
         'The number of classes from groundtruth labels and `num_classes` '
         'should equal')

     dice_metric = segmentation_metrics.DiceScore(num_classes=4)
     predictions = tf.constant(0.5, shape=tensor_shape, dtype=tf.float32)
     labels = tf.ones(shape=tensor_shape, dtype=tf.float32)

     with self.assertRaisesRegex(ValueError, expected_error):
         dice_metric.update_state(y_true=labels, y_pred=predictions)
Пример #2
0
 def test_forward_dice_score(self, num_classes, metric_type, output,
                             expected_score):
     """Compares the computed Dice score with the expected value.

     Predictions are a tensor filled with the constant `output`, labels are
     all ones, and the metric is built with `per_class_metric=True`. After a
     single `update_state` call the result must match `expected_score`
     within an absolute tolerance of 1e-2.

     Args:
       num_classes: Size of the last tensor dimension and the metric's
         `num_classes` argument.
       metric_type: Forwarded to `DiceScore` as `metric_type`.
       output: Constant value used to fill the prediction tensor.
       expected_score: Value the metric result is compared against.
     """
     volume_shape = [2, 128, 128, 128, num_classes]

     dice_metric = segmentation_metrics.DiceScore(
         num_classes=num_classes,
         metric_type=metric_type,
         per_class_metric=True)
     predictions = tf.constant(output, shape=volume_shape, dtype=tf.float32)
     labels = tf.ones(shape=volume_shape, dtype=tf.float32)
     dice_metric.update_state(y_true=labels, y_pred=predictions)

     actual_score = dice_metric.result().numpy()
     failure_message = (
         'Output metric {} does not match expected metric {}.'.format(
             actual_score, expected_score))
     self.assertAllClose(
         actual_score, expected_score, atol=1e-2, msg=failure_message)
Пример #3
0
  def build_metrics(self,
                    training: bool = True) -> Sequence[tf.keras.metrics.Metric]:
    """Builds the streaming metrics for training or validation.

    Args:
      training: When True, a categorical-accuracy metric is returned. When
        False, a generalized Dice score metric is stored on `self.metrics`
        and an empty list is returned.

    Returns:
      The list of metrics handed back to the caller; it is empty in the
      validation case.
    """
    class_count = self.task_config.model.num_classes

    if training:
      return [
          tf.keras.metrics.CategoricalAccuracy(
              name='train_categorical_accuracy', dtype=tf.float32)
      ]

    # NOTE(review): the validation metric is assigned to `self.metrics`
    # instead of being returned. Presumably it is read from that attribute
    # elsewhere in the class -- confirm before changing this to a return.
    report_per_class = self.task_config.evaluation.report_per_class_metric
    dice_metric = segmentation_metrics.DiceScore(
        num_classes=class_count,
        metric_type='generalized',
        per_class_metric=report_per_class,
        name='val_generalized_dice',
        dtype=tf.float32)
    self.metrics = [dice_metric]
    return []