# Assumed imports for these excerpts: tensorflow as tf, typing.Sequence,
# and the project's segmentation_metrics module.


def test_num_classes_not_equal(self):
  metric = segmentation_metrics.DiceScore(num_classes=4)
  # Labels carry 2 channels, but the metric was built with num_classes=4.
  y_pred = tf.constant(0.5, shape=[2, 128, 128, 128, 2], dtype=tf.float32)
  y_true = tf.ones(shape=[2, 128, 128, 128, 2], dtype=tf.float32)
  with self.assertRaisesRegex(
      ValueError,
      'The number of classes from groundtruth labels and `num_classes` '
      'should equal'):
    metric.update_state(y_true=y_true, y_pred=y_pred)
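# For reference, a minimal standalone sketch (not part of the test class) of
# the happy path the check above guards: the last label dimension matching
# `num_classes`. The helper name is hypothetical and the shapes are
# illustrative only.
def _sketch_matching_num_classes():
  metric = segmentation_metrics.DiceScore(num_classes=2)
  y_pred = tf.constant(0.5, shape=[1, 32, 32, 32, 2], dtype=tf.float32)
  y_true = tf.ones(shape=[1, 32, 32, 32, 2], dtype=tf.float32)
  metric.update_state(y_true=y_true, y_pred=y_pred)  # No ValueError raised.
  return metric.result().numpy()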
# Parameterized over (num_classes, metric_type, output, expected_score),
# e.g. via absl.testing.parameterized.
def test_forward_dice_score(self, num_classes, metric_type, output,
                            expected_score):
  metric = segmentation_metrics.DiceScore(
      num_classes=num_classes,
      metric_type=metric_type,
      per_class_metric=True)
  y_pred = tf.constant(
      output, shape=[2, 128, 128, 128, num_classes], dtype=tf.float32)
  y_true = tf.ones(shape=[2, 128, 128, 128, num_classes], dtype=tf.float32)
  metric.update_state(y_true=y_true, y_pred=y_pred)
  actual_score = metric.result().numpy()
  self.assertAllClose(
      actual_score,
      expected_score,
      atol=1e-2,
      msg='Output metric {} does not match expected metric {}.'.format(
          actual_score, expected_score))
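# A plain-TF reference for the soft dice score, useful as a hedged
# cross-check of the metric exercised above: 2*|P.G| / (|P| + |G|), summed
# over all voxels and classes. This is the standard formula, not
# `segmentation_metrics`' exact internals; in particular, the 'generalized'
# variant additionally weights each class, which this sketch omits.
def _sketch_reference_dice(y_true, y_pred, eps=1e-6):
  intersection = tf.reduce_sum(y_true * y_pred)
  denominator = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
  # eps guards against empty masks (both tensors all zeros).
  return (2.0 * intersection + eps) / (denominator + eps)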
def build_metrics(self,
                  training: bool = True) -> Sequence[tf.keras.metrics.Metric]:
  """Gets streaming metrics for training/validation."""
  metrics = []
  num_classes = self.task_config.model.num_classes
  if training:
    metrics.extend([
        tf.keras.metrics.CategoricalAccuracy(
            name='train_categorical_accuracy', dtype=tf.float32)
    ])
  else:
    # Validation metrics are stored on the task (`self.metrics`) rather
    # than returned, so they can be updated on the host outside the
    # replica context; the returned list stays empty during evaluation.
    self.metrics = [
        segmentation_metrics.DiceScore(
            num_classes=num_classes,
            metric_type='generalized',
            per_class_metric=self.task_config.evaluation
            .report_per_class_metric,
            name='val_generalized_dice',
            dtype=tf.float32)
    ]
  return metrics
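# A hedged sketch of how the validation metrics stored on `self.metrics`
# might be consumed: the eval loop feeds per-step outputs back through
# `aggregate_logs` outside the tf.distribute replica context. The
# `aggregate_logs(state, step_outputs)` signature follows the base Task
# API; the 'y_true'/'y_pred' keys are assumptions, and the body is
# illustrative rather than the actual implementation.
def aggregate_logs(self, state=None, step_outputs=None):
  if state is None:
    # First step of an eval epoch: reset the streaming dice metrics.
    for metric in self.metrics:
      metric.reset_states()
    state = self.metrics
  for metric in self.metrics:
    metric.update_state(step_outputs['y_true'], step_outputs['y_pred'])
  return state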