def default_multi_class_classification_specs(
    model_names: Optional[List[Text]] = None,
    output_names: Optional[List[Text]] = None,
    output_weights: Optional[Dict[Text, float]] = None,
    binarize: Optional[config.BinarizationOptions] = None,
    aggregate: Optional[config.AggregationOptions] = None,
    sparse: bool = True) -> List[config.MetricsSpec]:
  """Returns default metric specs for multi-class classification problems.

  Args:
    model_names: Optional model names if multi-model evaluation.
    output_names: Optional list of output names (if multi-output model).
    output_weights: Optional output weights for creating overall metric
      aggregated across outputs (if multi-output model). If a weight is not
      provided for an output, its weight defaults to 0.0 (i.e. output
      ignored).
    binarize: Optional settings for binarizing multi-class/multi-label
      metrics.
    aggregate: Optional settings for aggregating multi-class/multi-label
      metrics.
    sparse: True if the labels are sparse.
  """
  if sparse:
    metrics = [
        tf.keras.metrics.SparseCategoricalCrossentropy(name='loss'),
        tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
    ]
  else:
    metrics = [
        tf.keras.metrics.CategoricalCrossentropy(name='loss'),
        tf.keras.metrics.CategoricalAccuracy(name='accuracy')
    ]
  metrics.append(
      multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot())
  if binarize is not None:
    for top_k in binarize.top_k_list.values:
      metrics.extend([
          tf.keras.metrics.Precision(name='precision', top_k=top_k),
          tf.keras.metrics.Recall(name='recall', top_k=top_k)
      ])
    # Top-k binarization is handled above via the keras metrics; strip
    # top_k_list before forwarding the remaining binarization settings to the
    # binary classification specs.
    binarize_without_top_k = config.BinarizationOptions()
    binarize_without_top_k.CopyFrom(binarize)
    binarize_without_top_k.ClearField('top_k_list')
    binarize = binarize_without_top_k
  multi_class_metrics = specs_from_metrics(
      metrics,
      model_names=model_names,
      output_names=output_names,
      output_weights=output_weights)
  if aggregate is None:
    aggregate = config.AggregationOptions(micro_average=True)
  multi_class_metrics.extend(
      default_binary_classification_specs(
          model_names=model_names,
          output_names=output_names,
          output_weights=output_weights,
          binarize=binarize,
          aggregate=aggregate))
  return multi_class_metrics
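
# A minimal usage sketch (not part of the library): the 'candidate' model
# name and the top-k values below are hypothetical. Setting top_k_list adds
# Precision/Recall metrics at top_k=1 and top_k=3; since no aggregate is
# passed, the function defaults to AggregationOptions(micro_average=True)
# when appending the binary classification specs.
def _example_usage():
  binarize = config.BinarizationOptions()
  binarize.top_k_list.values.extend([1, 3])
  return default_multi_class_classification_specs(
      model_names=['candidate'], binarize=binarize, sparse=True)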

def testMultiClassConfusionMatrixPlotWithThresholds(self, kwargs):
  computation = (
      multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot(
          **kwargs).computations()[0])

  example1 = {
      'labels': np.array([2.0]),
      'predictions': np.array([0.2, 0.35, 0.45]),
      'example_weights': np.array([1.0])
  }
  example2 = {
      'labels': np.array([0.0]),
      'predictions': np.array([0.1, 0.35, 0.55]),
      'example_weights': np.array([1.0])
  }
  example3 = {
      'labels': np.array([1.0]),
      'predictions': np.array([0.3, 0.25, 0.45]),
      'example_weights': np.array([1.0])
  }
  example4 = {
      'labels': np.array([1.0]),
      'predictions': np.array([0.1, 0.9, 0.0]),
      'example_weights': np.array([1.0])
  }
  example5 = {
      'labels': np.array([1.0]),
      'predictions': np.array([0.1, 0.8, 0.1]),
      'example_weights': np.array([1.0])
  }
  example6 = {
      'labels': np.array([2.0]),
      'predictions': np.array([0.3, 0.25, 0.45]),
      'example_weights': np.array([1.0])
  }

  with beam.Pipeline() as pipeline:
    # pylint: disable=no-value-for-parameter
    result = (
        pipeline
        | 'Create' >> beam.Create(
            [example1, example2, example3, example4, example5, example6])
        | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
        | 'AddSlice' >> beam.Map(lambda x: ((), x))
        | 'ComputePlot' >> beam.CombinePerKey(computation.combiner))
    # pylint: enable=no-value-for-parameter

    def check_result(got):
      try:
        self.assertLen(got, 1)
        got_slice_key, got_plots = got[0]
        self.assertEqual(got_slice_key, ())
        self.assertLen(got_plots, 1)
        key = metric_types.PlotKey(name='multi_class_confusion_matrix_plot')
        got_matrix = got_plots[key]
        self.assertProtoEquals(
            """
            matrices {
              threshold: 0.0
              entries {
                actual_class_id: 0
                predicted_class_id: 2
                num_weighted_examples: 1.0
              }
              entries {
                actual_class_id: 1
                predicted_class_id: 1
                num_weighted_examples: 2.0
              }
              entries {
                actual_class_id: 1
                predicted_class_id: 2
                num_weighted_examples: 1.0
              }
              entries {
                actual_class_id: 2
                predicted_class_id: 2
                num_weighted_examples: 2.0
              }
            }
            matrices {
              threshold: 0.5
              entries {
                actual_class_id: 0
                predicted_class_id: 2
                num_weighted_examples: 1.0
              }
              entries {
                actual_class_id: 1
                predicted_class_id: -1
                num_weighted_examples: 1.0
              }
              entries {
                actual_class_id: 1
                predicted_class_id: 1
                num_weighted_examples: 2.0
              }
              entries {
                actual_class_id: 2
                predicted_class_id: -1
                num_weighted_examples: 2.0
              }
            }
            matrices {
              threshold: 1.0
              entries {
                predicted_class_id: -1
                num_weighted_examples: 1.0
              }
              entries {
                actual_class_id: 1
                predicted_class_id: -1
                num_weighted_examples: 3.0
              }
              entries {
                actual_class_id: 2
                predicted_class_id: -1
                num_weighted_examples: 2.0
              }
            }
            """, got_matrix)
      except AssertionError as err:
        raise util.BeamAssertException(err)

    util.assert_that(result, check_result, label='result')
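
# Sketch (hypothetical): the parameterized `kwargs` for the test above are
# elided. Given the expected thresholds of 0.0, 0.5 and 1.0, one plausible
# instantiation, assuming the metric accepts an explicit `thresholds` list,
# is shown below. In the expected proto, predicted_class_id: -1 marks
# examples whose highest class probability did not exceed the threshold
# (i.e. no class was predicted).
def _example_plot_computation():
  plot = multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot(
      thresholds=[0.0, 0.5, 1.0])
  return plot.computations()[0]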

def testMultiClassConfusionMatrixPlotWithStringLabels(self):
  computations = (
      multi_class_confusion_matrix_plot.MultiClassConfusionMatrixPlot()
      .computations(example_weighted=True))
  matrices = computations[0]
  plot = computations[1]

  # Examples from b/149558504.
  example1 = {
      'labels': np.array([['unacc']]),
      'predictions': {
          'probabilities':
              np.array([[1.0000000e+00, 6.9407083e-24, 2.7419115e-38,
                         0.0000000e+00]]),
          'all_classes': np.array([['unacc', 'acc', 'vgood', 'good']]),
      },
      'example_weights': np.array([0.5])
  }
  example2 = {
      'labels': np.array([['vgood']]),
      'predictions': {
          'probabilities': np.array([[0.2, 0.3, 0.4, 0.1]]),
          'all_classes': np.array([['unacc', 'acc', 'vgood', 'good']]),
      },
      'example_weights': np.array([1.0])
  }

  with beam.Pipeline() as pipeline:
    # pylint: disable=no-value-for-parameter
    result = (
        pipeline
        | 'Create' >> beam.Create([example1, example2])
        | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
        | 'AddSlice' >> beam.Map(lambda x: ((), x))
        | 'ComputeMatrices' >> beam.CombinePerKey(matrices.combiner)
        | 'ComputePlot' >> beam.Map(lambda x: (x[0], plot.result(x[1]))))
    # pylint: enable=no-value-for-parameter

    def check_result(got):
      try:
        self.assertLen(got, 1)
        got_slice_key, got_plots = got[0]
        self.assertEqual(got_slice_key, ())
        self.assertLen(got_plots, 1)
        key = metric_types.PlotKey(
            name='multi_class_confusion_matrix_plot', example_weighted=True)
        got_matrix = got_plots[key]
        self.assertProtoEquals(
            """
            matrices {
              threshold: 0.0
              entries {
                actual_class_id: 0
                predicted_class_id: 0
                num_weighted_examples: 0.5
              }
              entries {
                actual_class_id: 2
                predicted_class_id: 2
                num_weighted_examples: 1.0
              }
            }
            """, got_matrix)
      except AssertionError as err:
        raise util.BeamAssertException(err)

    util.assert_that(result, check_result, label='result')