示例#1
0
def default_binary_classification_specs(
        model_names: Optional[List[Text]] = None,
        output_names: Optional[List[Text]] = None,
        output_weights: Optional[Dict[Text, float]] = None,
        binarize: Optional[config.BinarizationOptions] = None,
        aggregate: Optional[config.AggregationOptions] = None,
        include_loss: bool = True) -> List[config.MetricsSpec]:
    """Builds the default metric specs for binary classification problems.

    The default set is: binary accuracy, two AUC metrics (one with the default
    curve and one with curve='PR'), precision, recall, mean label, mean
    prediction, calibration, the confusion matrix plot and the calibration
    plot. A binary cross-entropy metric named 'loss' is added on request.

    Args:
      model_names: Optional model names (if multi-model evaluation).
      output_names: Optional list of output names (if multi-output model).
      output_weights: Optional output weights for creating an overall metric
        aggregated across outputs (if multi-output model). If a weight is not
        provided for an output, its weight defaults to 0.0 (i.e. the output is
        ignored).
      binarize: Optional settings for binarizing multi-class/multi-label
        metrics.
      aggregate: Optional settings for aggregating multi-class/multi-label
        metrics.
      include_loss: True to include the loss metric.

    Returns:
      The metric specs that `specs_from_metrics` creates for the default
      metrics and the given options.
    """
    # Both AUC metrics share the same number of thresholds.
    threshold_count = binary_confusion_matrices.DEFAULT_NUM_THRESHOLDS

    default_metrics = [
        tf.keras.metrics.BinaryAccuracy(name='accuracy'),
        tf.keras.metrics.AUC(name='auc', num_thresholds=threshold_count),
        # The name (including its 'precison' spelling) is kept exactly as is;
        # it is stated to match the default name used by estimator.
        tf.keras.metrics.AUC(
            name='auc_precison_recall',
            curve='PR',
            num_thresholds=threshold_count),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        calibration.MeanLabel(name='mean_label'),
        calibration.MeanPrediction(name='mean_prediction'),
        calibration.Calibration(name='calibration'),
        confusion_matrix_plot.ConfusionMatrixPlot(
            name='confusion_matrix_plot'),
        calibration_plot.CalibrationPlot(name='calibration_plot'),
    ]
    # The loss metric is only constructed when it was asked for and always
    # comes last in the list.
    loss_metrics = (
        [tf.keras.metrics.BinaryCrossentropy(name='loss')]
        if include_loss else [])

    return specs_from_metrics(
        default_metrics + loss_metrics,
        model_names=model_names,
        output_names=output_names,
        output_weights=output_weights,
        binarize=binarize,
        aggregate=aggregate)
    def testConfusionMatrixPlot(self):
        """Checks the plot computed by ConfusionMatrixPlot for four examples.

        The examples are grouped under a single empty slice key, combined with
        the histogram computation's combiner, and then passed through the
        `result` functions of the matrices and plot computations. The plot
        stored under the 'confusion_matrix_plot' key is compared against the
        expected confusion matrices.
        """
        plot_computations = confusion_matrix_plot.ConfusionMatrixPlot(
            num_thresholds=4).computations()
        # The first three computations are used as histogram, matrices, plot.
        histogram_computation = plot_computations[0]
        matrices_computation = plot_computations[1]
        plot_computation = plot_computations[2]

        def make_example(label, prediction):
            # Every example is a single-element batch with an example weight
            # of 1.0.
            return {
                'labels': np.array([label]),
                'predictions': np.array([prediction]),
                'example_weights': np.array([1.0]),
            }

        examples = [
            make_example(0.0, 0.0),
            make_example(0.0, 0.5),
            make_example(1.0, 0.3),
            make_example(1.0, 0.9),
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            sliced_inputs = (
                pipeline
                | 'Create' >> beam.Create(examples)
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                # Key every input with the empty slice key.
                | 'AddSlice' >> beam.Map(lambda inputs: ((), inputs)))
            histograms = (
                sliced_inputs
                | 'ComputeHistogram' >> beam.CombinePerKey(
                    histogram_computation.combiner))
            result = (
                histograms
                | 'ComputeMatrices' >> beam.Map(
                    lambda kv: (kv[0], matrices_computation.result(kv[1])))
                | 'ComputePlot' >> beam.Map(
                    lambda kv: (kv[0], plot_computation.result(kv[1]))))
            # pylint: enable=no-value-for-parameter

            def check_result(got):
                # Expected plot in proto text format, one matrices entry per
                # threshold.
                expected_plot_text = """
              matrices {
                threshold: -1e-06
                false_positives: 2.0
                true_positives: 2.0
                precision: 0.5
                recall: 1.0
              }
              matrices {
                false_positives: 2.0
                true_positives: 2.0
                precision: 0.5
                recall: 1.0
              }
              matrices {
                threshold: 0.25
                true_negatives: 1.0
                false_positives: 1.0
                true_positives: 2.0
                precision: 0.6666667
                recall: 1.0
              }
              matrices {
                threshold: 0.5
                false_negatives: 1.0
                true_negatives: 2.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
              }
              matrices {
                threshold: 0.75
                false_negatives: 1.0
                true_negatives: 2.0
                true_positives: 1.0
                precision: 1.0
                recall: 0.5
              }
              matrices {
                threshold: 1.0
                false_negatives: 2.0
                true_negatives: 2.0
                precision: 1.0
                recall: 0.0
              }
          """
                try:
                    # Exactly one (slice key, plots) pair is expected.
                    self.assertLen(got, 1)
                    slice_key, plots = got[0]
                    self.assertEqual(slice_key, ())
                    self.assertLen(plots, 1)
                    plot_key = metric_types.PlotKey(
                        name='confusion_matrix_plot')
                    self.assertIn(plot_key, plots)
                    self.assertProtoEquals(
                        expected_plot_text, plots[plot_key])

                except AssertionError as err:
                    # Re-raise as the Beam assertion type so the failure is
                    # reported through assert_that.
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')