# Imports assumed by this excerpt (the enclosing module is TFMA's
# metric_specs; specs_from_metrics is defined earlier in that module).
from typing import Dict, List, Optional, Text

import tensorflow as tf

from tensorflow_model_analysis import config
from tensorflow_model_analysis.metrics import binary_confusion_matrices
from tensorflow_model_analysis.metrics import calibration
from tensorflow_model_analysis.metrics import calibration_plot
from tensorflow_model_analysis.metrics import confusion_matrix_plot


def default_binary_classification_specs(
    model_names: Optional[List[Text]] = None,
    output_names: Optional[List[Text]] = None,
    output_weights: Optional[Dict[Text, float]] = None,
    binarize: Optional[config.BinarizationOptions] = None,
    aggregate: Optional[config.AggregationOptions] = None,
    include_loss: bool = True) -> List[config.MetricsSpec]:
  """Returns default metric specs for binary classification problems.

  Args:
    model_names: Optional model names (if multi-model evaluation).
    output_names: Optional list of output names (if multi-output model).
    output_weights: Optional output weights for creating an overall metric
      aggregated across outputs (if multi-output model). If a weight is not
      provided for an output, its weight defaults to 0.0 (i.e. the output is
      ignored).
    binarize: Optional settings for binarizing multi-class/multi-label
      metrics.
    aggregate: Optional settings for aggregating multi-class/multi-label
      metrics.
    include_loss: True to include loss.
  """
  metrics = [
      tf.keras.metrics.BinaryAccuracy(name='accuracy'),
      tf.keras.metrics.AUC(
          name='auc',
          num_thresholds=binary_confusion_matrices.DEFAULT_NUM_THRESHOLDS),
      tf.keras.metrics.AUC(
          # Matches the default name used by the estimator.
          name='auc_precision_recall',
          curve='PR',
          num_thresholds=binary_confusion_matrices.DEFAULT_NUM_THRESHOLDS),
      tf.keras.metrics.Precision(name='precision'),
      tf.keras.metrics.Recall(name='recall'),
      calibration.MeanLabel(name='mean_label'),
      calibration.MeanPrediction(name='mean_prediction'),
      calibration.Calibration(name='calibration'),
      confusion_matrix_plot.ConfusionMatrixPlot(name='confusion_matrix_plot'),
      calibration_plot.CalibrationPlot(name='calibration_plot'),
  ]
  if include_loss:
    metrics.append(tf.keras.metrics.BinaryCrossentropy(name='loss'))
  return specs_from_metrics(
      metrics,
      model_names=model_names,
      output_names=output_names,
      output_weights=output_weights,
      binarize=binarize,
      aggregate=aggregate)
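# A minimal usage sketch (illustrative only, not part of the module): the
# returned specs plug straight into an EvalConfig. The 'label' key and the
# helper name below are hypothetical.
def _example_binary_classification_eval_config() -> config.EvalConfig:
  return config.EvalConfig(
      model_specs=[config.ModelSpec(label_key='label')],
      metrics_specs=default_binary_classification_specs())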
# Imports assumed by this excerpt; the test method lives in a
# tf.test.TestCase subclass (assertProtoEquals comes from there).
import apache_beam as beam
import numpy as np
from apache_beam.testing import util
from tensorflow_model_analysis.metrics import confusion_matrix_plot
from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.metrics import metric_util


def testConfusionMatrixPlot(self):
  # The plot metric decomposes into three computations: a histogram combiner,
  # a derived confusion-matrices computation, and the final plot proto.
  computations = confusion_matrix_plot.ConfusionMatrixPlot(
      num_thresholds=4).computations()
  histogram = computations[0]
  matrices = computations[1]
  plot = computations[2]

  example1 = {
      'labels': np.array([0.0]),
      'predictions': np.array([0.0]),
      'example_weights': np.array([1.0]),
  }
  example2 = {
      'labels': np.array([0.0]),
      'predictions': np.array([0.5]),
      'example_weights': np.array([1.0]),
  }
  example3 = {
      'labels': np.array([1.0]),
      'predictions': np.array([0.3]),
      'example_weights': np.array([1.0]),
  }
  example4 = {
      'labels': np.array([1.0]),
      'predictions': np.array([0.9]),
      'example_weights': np.array([1.0]),
  }

  with beam.Pipeline() as pipeline:
    # pylint: disable=no-value-for-parameter
    result = (
        pipeline
        | 'Create' >> beam.Create([example1, example2, example3, example4])
        | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
        | 'AddSlice' >> beam.Map(lambda x: ((), x))
        | 'ComputeHistogram' >> beam.CombinePerKey(histogram.combiner)
        | 'ComputeMatrices' >> beam.Map(
            lambda x: (x[0], matrices.result(x[1])))  # pyformat: ignore
        | 'ComputePlot' >> beam.Map(lambda x: (x[0], plot.result(x[1]))))
    # pylint: enable=no-value-for-parameter

    def check_result(got):
      try:
        self.assertLen(got, 1)
        got_slice_key, got_plots = got[0]
        self.assertEqual(got_slice_key, ())
        self.assertLen(got_plots, 1)
        key = metric_types.PlotKey(name='confusion_matrix_plot')
        self.assertIn(key, got_plots)
        got_plot = got_plots[key]
        # num_thresholds=4 yields matrices at thresholds
        # [-1e-6, 0.0, 0.25, 0.5, 0.75, 1.0].
        self.assertProtoEquals(
            """
            matrices {
              threshold: -1e-06
              false_positives: 2.0
              true_positives: 2.0
              precision: 0.5
              recall: 1.0
            }
            matrices {
              false_positives: 2.0
              true_positives: 2.0
              precision: 0.5
              recall: 1.0
            }
            matrices {
              threshold: 0.25
              true_negatives: 1.0
              false_positives: 1.0
              true_positives: 2.0
              precision: 0.6666667
              recall: 1.0
            }
            matrices {
              threshold: 0.5
              false_negatives: 1.0
              true_negatives: 2.0
              true_positives: 1.0
              precision: 1.0
              recall: 0.5
            }
            matrices {
              threshold: 0.75
              false_negatives: 1.0
              true_negatives: 2.0
              true_positives: 1.0
              precision: 1.0
              recall: 0.5
            }
            matrices {
              threshold: 1.0
              false_negatives: 2.0
              true_negatives: 2.0
              precision: 1.0
              recall: 0.0
            }
            """, got_plot)
      except AssertionError as err:
        raise util.BeamAssertException(err)

    util.assert_that(result, check_result, label='result')
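# Sanity-check sketch (illustrative, plain Python; the helper name is
# hypothetical): re-derives an interior-threshold entry of the expected proto
# above, treating a prediction as positive when it is strictly greater than
# the threshold.
def _confusion_counts_at(threshold, labels, predictions):
  """Returns (tp, fp, tn, fn) for binary labels at one decision threshold."""
  tp = fp = tn = fn = 0
  for label, prediction in zip(labels, predictions):
    if prediction > threshold:
      if label == 1.0:
        tp += 1
      else:
        fp += 1
    else:
      if label == 1.0:
        fn += 1
      else:
        tn += 1
  return tp, fp, tn, fn

# _confusion_counts_at(0.5, [0.0, 0.0, 1.0, 1.0], [0.0, 0.5, 0.3, 0.9])
# returns (1, 0, 2, 1): one true positive (0.9), two true negatives (0.0 and
# 0.5), and one false negative (0.3), matching the matrices entry for
# threshold 0.5 asserted above.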