Example #1
def default_regression_specs(
    model_names: Optional[List[Text]] = None,
    output_names: Optional[List[Text]] = None,
    loss_functions: Optional[List[Union[tf.keras.metrics.Metric,
                                        tf.keras.losses.Loss]]] = None,
    min_value: Optional[float] = None,
    max_value: Optional[float] = None) -> List[config.MetricsSpec]:
  """Returns default metric specs for for regression problems.

  Args:
    model_names: Optional model names (if multi-model evaluation).
    output_names: Optional list of output names (if multi-output model).
    loss_functions: Loss functions to use (if None MSE is used).
    min_value: Min value for calibration plot (if None no plot will be created).
    max_value: Max value for calibration plot (if None no plot will be created).
  """

  if loss_functions is None:
    loss_functions = [tf.keras.metrics.MeanSquaredError(name='mse')]

  metrics = [
      tf.keras.metrics.Accuracy(name='accuracy'),
      calibration.MeanLabel(name='mean_label'),
      calibration.MeanPrediction(name='mean_prediction'),
      calibration.Calibration(name='calibration'),
  ]
  for fn in loss_functions:
    metrics.append(fn)
  if min_value is not None and max_value is not None:
    metrics.append(
        calibration_plot.CalibrationPlot(
            name='calibration_plot', left=min_value, right=max_value))

  return specs_from_metrics(
      metrics, model_names=model_names, output_names=output_names)
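
# A minimal usage sketch, not part of the example above: assuming this
# function lives in TFMA's metric_specs module and that `config` is
# tensorflow_model_analysis's config module, the returned specs can be passed
# straight into an EvalConfig. The label range [0.0, 100.0] is a hypothetical
# value chosen so that a calibration plot is included.
regression_specs = default_regression_specs(min_value=0.0, max_value=100.0)
eval_config = config.EvalConfig(
    model_specs=[config.ModelSpec(label_key='label')],
    slicing_specs=[config.SlicingSpec()],
    metrics_specs=regression_specs)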
Example #2
  def testToComputations(self):
    computations = metric_specs.to_computations(
        metric_specs.specs_from_metrics(
            {
                'output_name': [
                    tf.keras.metrics.MeanSquaredError('mse'),
                    calibration.MeanLabel('mean_label')
                ]
            },
            model_names=['model_name'],
            binarize=config.BinarizationOptions(class_ids={'values': [0, 1]}),
            aggregate=config.AggregationOptions(macro_average=True)),
        config.EvalConfig())

    keys = []
    for m in computations:
      for k in m.keys:
        if not k.name.startswith('_'):
          keys.append(k)
    self.assertLen(keys, 8)
    self.assertIn(metric_types.MetricKey(name='example_count'), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='weighted_example_count',
            model_name='model_name',
            output_name='output_name'), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='mse',
            model_name='model_name',
            output_name='output_name',
            sub_key=metric_types.SubKey(class_id=0)), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='mse',
            model_name='model_name',
            output_name='output_name',
            sub_key=metric_types.SubKey(class_id=1)), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='mse', model_name='model_name', output_name='output_name'),
        keys)
    self.assertIn(
        metric_types.MetricKey(
            name='mean_label',
            model_name='model_name',
            output_name='output_name',
            sub_key=metric_types.SubKey(class_id=0)), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='mean_label',
            model_name='model_name',
            output_name='output_name',
            sub_key=metric_types.SubKey(class_id=1)), keys)
    self.assertIn(
        metric_types.MetricKey(
            name='mean_label',
            model_name='model_name',
            output_name='output_name'), keys)
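
# Key-count arithmetic behind the assertLen above (illustrative only, not part
# of the original test): binarizing on class_ids [0, 1] plus macro averaging
# yields three keys per metric, and example_count / weighted_example_count add
# two more.
num_metrics = 2          # 'mse' and 'mean_label'
keys_per_metric = 3      # class_id=0, class_id=1, plus the macro-averaged key
bookkeeping_keys = 2     # 'example_count' and 'weighted_example_count'
assert num_metrics * keys_per_metric + bookkeeping_keys == 8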
Example #3
def default_binary_classification_specs(
        model_names: Optional[List[Text]] = None,
        output_names: Optional[List[Text]] = None,
        output_weights: Optional[Dict[Text, float]] = None,
        binarize: Optional[config.BinarizationOptions] = None,
        aggregate: Optional[config.AggregationOptions] = None,
        include_loss: bool = True) -> List[config.MetricsSpec]:
    """Returns default metric specs for binary classification problems.

  Args:
    model_names: Optional model names (if multi-model evaluation).
    output_names: Optional list of output names (if multi-output model).
    output_weights: Optional output weights for creating overall metric
      aggregated across outputs (if multi-output model). If a weight is not
      provided for an output, its weight defaults to 0.0 (i.e. the output is ignored).
    binarize: Optional settings for binarizing multi-class/multi-label metrics.
    aggregate: Optional settings for aggregating multi-class/multi-label
      metrics.
    include_loss: True to include loss.
  """

    metrics = [
        tf.keras.metrics.BinaryAccuracy(name='accuracy'),
        tf.keras.metrics.AUC(
            name='auc',
            num_thresholds=binary_confusion_matrices.DEFAULT_NUM_THRESHOLDS),
        tf.keras.metrics.AUC(
            # Matches default name used by estimator.
            name='auc_precison_recall',
            curve='PR',
            num_thresholds=binary_confusion_matrices.DEFAULT_NUM_THRESHOLDS),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        calibration.MeanLabel(name='mean_label'),
        calibration.MeanPrediction(name='mean_prediction'),
        calibration.Calibration(name='calibration'),
        confusion_matrix_plot.ConfusionMatrixPlot(
            name='confusion_matrix_plot'),
        calibration_plot.CalibrationPlot(name='calibration_plot')
    ]
    if include_loss:
        metrics.append(tf.keras.metrics.BinaryCrossentropy(name='loss'))

    return specs_from_metrics(metrics,
                              model_names=model_names,
                              output_names=output_names,
                              output_weights=output_weights,
                              binarize=binarize,
                              aggregate=aggregate)
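
# Usage sketch for the variant above (illustrative; the output names and
# weights are hypothetical): for a two-output model, weight both outputs
# equally when building the cross-output aggregate metrics.
binary_specs = default_binary_classification_specs(
    output_names=['head_a', 'head_b'],
    output_weights={'head_a': 0.5, 'head_b': 0.5})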
Example #4
def default_binary_classification_specs(
        model_names: Optional[List[Text]] = None,
        output_names: Optional[List[Text]] = None,
        class_ids: Optional[List[int]] = None,
        k_list: Optional[List[int]] = None,
        top_k_list: Optional[List[int]] = None,
        include_loss: bool = True) -> List[config.MetricsSpec]:
    """Returns default metric specs for binary classification problems.

  Args:
    model_names: Optional model names (if multi-model evaluation).
    output_names: Optional list of output names (if multi-output model).
    class_ids: Optional class IDs to compute metrics for particular classes in a
      multi-class model. If output_names are provided, all outputs are assumed
      to use the same class IDs.
    k_list: Optional list of k values to compute metrics for the kth predicted
      values of a multi-class model prediction. If output_names are provided,
      all outputs are assumed to use the same k value.
    top_k_list: Optional list of top_k values to compute metrics for the top k
      predicted values in a multi-class model prediction. If output_names are
      provided, all outputs are assumed to use the same top_k value. Metrics and
      plots will be based on treating each predicted value in the top_k as
      though they were separate predictions.
    include_loss: True to include loss.
  """

    metrics = [
        tf.keras.metrics.BinaryAccuracy(name='accuracy'),
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.AUC(name='auc_pr', curve='PR'),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        calibration.MeanLabel(name='mean_label'),
        calibration.MeanPrediction(name='mean_prediction'),
        calibration.Calibration(name='calibration'),
        auc_plot.AUCPlot(name='auc_plot'),
        calibration_plot.CalibrationPlot(name='calibration_plot')
    ]
    if include_loss:
        metrics.append(tf.keras.metrics.BinaryCrossentropy(name='loss'))

    return specs_from_metrics(metrics,
                              model_names=model_names,
                              output_names=output_names,
                              class_ids=class_ids,
                              k_list=k_list,
                              top_k_list=top_k_list)
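
# Usage sketch for the variant above (illustrative only): binarize by class
# IDs 0 and 1, and additionally evaluate the top 3 predicted classes as if
# each were a separate prediction.
multiclass_specs = default_binary_classification_specs(
    class_ids=[0, 1], top_k_list=[3])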
Example #5
def default_binary_classification_specs(
        model_names: Optional[List[Text]] = None,
        output_names: Optional[List[Text]] = None,
        binarize: Optional[config.BinarizationOptions] = None,
        aggregate: Optional[config.AggregationOptions] = None,
        include_loss: bool = True) -> List[config.MetricsSpec]:
    """Returns default metric specs for binary classification problems.

  Args:
    model_names: Optional model names (if multi-model evaluation).
    output_names: Optional list of output names (if multi-output model).
    binarize: Optional settings for binarizing multi-class/multi-label metrics.
    aggregate: Optional settings for aggregating multi-class/multi-label
      metrics.
    include_loss: True to include loss.
  """

    metrics = [
        tf.keras.metrics.BinaryAccuracy(name='accuracy'),
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.AUC(name='auc_pr', curve='PR'),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        calibration.MeanLabel(name='mean_label'),
        calibration.MeanPrediction(name='mean_prediction'),
        calibration.Calibration(name='calibration'),
        auc_plot.AUCPlot(name='auc_plot'),
        calibration_plot.CalibrationPlot(name='calibration_plot')
    ]
    if include_loss:
        metrics.append(tf.keras.metrics.BinaryCrossentropy(name='loss'))

    return specs_from_metrics(metrics,
                              model_names=model_names,
                              output_names=output_names,
                              binarize=binarize,
                              aggregate=aggregate)
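
# Usage sketch for the variant above (illustrative only): binarize on class
# IDs and macro-average across them. The exact BinarizationOptions layout
# depends on the TFMA version; both forms appear in the examples on this page.
binary_specs = default_binary_classification_specs(
    binarize=config.BinarizationOptions(class_ids={'values': [0, 1]}),
    aggregate=config.AggregationOptions(macro_average=True))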
Example #6
    def testSpecsFromMetrics(self):
        metrics_specs = metric_specs.specs_from_metrics(
            {
                'output_name1': [
                    tf.keras.metrics.MeanSquaredError('mse'),
                    tf.keras.losses.MeanAbsoluteError(name='mae'),
                    calibration.MeanLabel('mean_label')
                ],
                'output_name2': [
                    tf.keras.metrics.RootMeanSquaredError('rmse'),
                    tf.keras.losses.MeanAbsolutePercentageError(name='mape'),
                    calibration.MeanPrediction('mean_prediction')
                ]
            },
            model_names=['model_name1', 'model_name2'],
            binarize=config.BinarizationOptions(class_ids={'values': [0, 1]}),
            aggregate=config.AggregationOptions(macro_average=True))

        self.assertLen(metrics_specs, 5)
        self.assertProtoEquals(
            metrics_specs[0],
            config.MetricsSpec(metrics=[
                config.MetricConfig(class_name='ExampleCount',
                                    config=json.dumps(
                                        {'name': 'example_count'})),
            ]))
        self.assertProtoEquals(
            metrics_specs[1],
            config.MetricsSpec(metrics=[
                config.MetricConfig(class_name='WeightedExampleCount',
                                    config=json.dumps(
                                        {'name': 'weighted_example_count'})),
            ],
                               model_names=['model_name1', 'model_name2'],
                               output_names=['output_name1']))
        self.assertProtoEquals(
            metrics_specs[2],
            config.MetricsSpec(metrics=[
                config.MetricConfig(class_name='MeanSquaredError',
                                    config=json.dumps(
                                        {
                                            'name': 'mse',
                                            'dtype': 'float32'
                                        },
                                        sort_keys=True)),
                config.MetricConfig(class_name='MeanAbsoluteError',
                                    module=metric_specs._TF_LOSSES_MODULE,
                                    config=json.dumps(
                                        {
                                            'reduction': 'auto',
                                            'name': 'mae'
                                        },
                                        sort_keys=True)),
                config.MetricConfig(class_name='MeanLabel',
                                    config=json.dumps({'name': 'mean_label'}))
            ],
                               model_names=['model_name1', 'model_name2'],
                               output_names=['output_name1'],
                               binarize=config.BinarizationOptions(
                                   class_ids={'values': [0, 1]}),
                               aggregate=config.AggregationOptions(
                                   macro_average=True)))
        self.assertProtoEquals(
            metrics_specs[3],
            config.MetricsSpec(metrics=[
                config.MetricConfig(class_name='WeightedExampleCount',
                                    config=json.dumps(
                                        {'name': 'weighted_example_count'})),
            ],
                               model_names=['model_name1', 'model_name2'],
                               output_names=['output_name2']))
        self.assertProtoEquals(
            metrics_specs[4],
            config.MetricsSpec(
                metrics=[
                    config.MetricConfig(class_name='RootMeanSquaredError',
                                        config=json.dumps(
                                            {
                                                'name': 'rmse',
                                                'dtype': 'float32'
                                            },
                                            sort_keys=True)),
                    config.MetricConfig(
                        class_name='MeanAbsolutePercentageError',
                        module=metric_specs._TF_LOSSES_MODULE,
                        config=json.dumps({
                            'reduction': 'auto',
                            'name': 'mape'
                        },
                                          sort_keys=True)),
                    config.MetricConfig(class_name='MeanPrediction',
                                        config=json.dumps(
                                            {'name': 'mean_prediction'}))
                ],
                model_names=['model_name1', 'model_name2'],
                output_names=['output_name2'],
                binarize=config.BinarizationOptions(
                    class_ids={'values': [0, 1]}),
                aggregate=config.AggregationOptions(macro_average=True)))
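
# Spec-count arithmetic behind the assertLen above (illustrative only, not
# part of the original test): one leading ExampleCount spec, plus a
# WeightedExampleCount spec and a metrics spec for each named output.
num_outputs = 2          # 'output_name1' and 'output_name2'
assert 1 + num_outputs * 2 == 5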
Example #7
class CalibrationMetricsTest(testutil.TensorflowModelAnalysisTest,
                             parameterized.TestCase):
    @parameterized.named_parameters(
        ('mean_label', calibration.MeanLabel(), 2.0 / 3.0),
        ('mean_prediction', calibration.MeanPrediction(), (0.3 + 0.9) / 3.0),
        ('calibration', calibration.Calibration(), (0.3 + 0.9) / 2.0))
    def testCalibrationMetricsWithoutWeights(self, metric, expected_value):
        computations = metric.computations()
        weighted_totals = computations[0]
        metric = computations[1]

        example1 = {
            'labels': np.array([0.0]),
            'predictions': np.array([0.0]),
            'example_weights': np.array([1.0]),
        }
        example2 = {
            'labels': np.array([1.0]),
            'predictions': np.array([0.3]),
            'example_weights': np.array([1.0]),
        }
        example3 = {
            'labels': np.array([1.0]),
            'predictions': np.array([0.9]),
            'example_weights': None,  # defaults to 1.0
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create([example1, example2, example3])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'ComputeWeightedTotals' >> beam.CombinePerKey(
                    weighted_totals.combiner)
                | 'ComputeMetric' >> beam.Map(lambda x:
                                              (x[0], metric.result(x[1]))))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    key = metric.keys[0]
                    self.assertDictElementsAlmostEqual(got_metrics,
                                                       {key: expected_value},
                                                       places=5)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')

    @parameterized.named_parameters(
        ('mean_label', calibration.MeanLabel(), 1.0 * 0.7 / 2.1),
        ('mean_prediction', calibration.MeanPrediction(),
         (1.0 * 0.5 + 0.7 * 0.7 + 0.5 * 0.9) / 2.1),
        ('calibration', calibration.Calibration(),
         (1.0 * 0.5 + 0.7 * 0.7 + 0.5 * 0.9) / (1.0 * 0.7)))
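    # The expected values in the parameters above follow from the weighted
    # examples below:
    #   sum of weights          = 0.5 + 0.7 + 0.9 = 2.1
    #   weighted label sum      = 1.0 * 0.7 = 0.7
    #   weighted prediction sum = 1.0 * 0.5 + 0.7 * 0.7 + 0.5 * 0.9 = 1.44
    # so mean_label = 0.7 / 2.1, mean_prediction = 1.44 / 2.1, and
    # calibration = weighted prediction sum / weighted label sum = 1.44 / 0.7.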
    def testCalibrationMetricsWithWeights(self, metric, expected_value):
        computations = metric.computations()
        weighted_totals = computations[0]
        metric = computations[1]

        example1 = {
            'labels': np.array([0.0]),
            'predictions': np.array([1.0]),
            'example_weights': np.array([0.5]),
        }
        example2 = {
            'labels': np.array([1.0]),
            'predictions': np.array([0.7]),
            'example_weights': np.array([0.7]),
        }
        example3 = {
            'labels': np.array([0.0]),
            'predictions': np.array([0.5]),
            'example_weights': np.array([0.9]),
        }

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            result = (
                pipeline
                | 'Create' >> beam.Create([example1, example2, example3])
                | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSlice' >> beam.Map(lambda x: ((), x))
                | 'ComputeWeightedTotals' >> beam.CombinePerKey(
                    weighted_totals.combiner)
                | 'ComputeMetric' >> beam.Map(lambda x:
                                              (x[0], metric.result(x[1]))))

            # pylint: enable=no-value-for-parameter

            def check_result(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    key = metric.keys[0]
                    self.assertDictElementsAlmostEqual(got_metrics,
                                                       {key: expected_value},
                                                       places=5)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(result, check_result, label='result')
Example #8
    def testSpecsFromMetrics(self):
        metrics_specs = metric_specs.specs_from_metrics(
            {
                'output_name1': [
                    tf.keras.metrics.MeanSquaredError('mse'),
                    calibration.MeanLabel('mean_label')
                ],
                'output_name2': [
                    tf.keras.metrics.RootMeanSquaredError('rmse'),
                    calibration.MeanPrediction('mean_prediction')
                ]
            },
            model_names=['model_name1', 'model_name2'],
            binarize=config.BinarizationOptions(class_ids=[0, 1]),
            aggregate=config.AggregationOptions(macro_average=True))

        self.assertLen(metrics_specs, 5)
        self.assertProtoEquals(
            metrics_specs[0],
            config.MetricsSpec(metrics=[
                config.MetricConfig(class_name='ExampleCount',
                                    config=json.dumps(
                                        {'name': 'example_count'})),
            ]))
        self.assertProtoEquals(
            metrics_specs[1],
            config.MetricsSpec(metrics=[
                config.MetricConfig(class_name='WeightedExampleCount',
                                    config=json.dumps(
                                        {'name': 'weighted_example_count'})),
            ],
                               model_names=['model_name1', 'model_name2'],
                               output_names=['output_name1']))
        self.assertProtoEquals(
            metrics_specs[2],
            config.MetricsSpec(
                metrics=[
                    config.MetricConfig(class_name='MeanSquaredError',
                                        config=json.dumps({
                                            'name': 'mse',
                                            'dtype': 'float32'
                                        })),
                    config.MetricConfig(class_name='MeanLabel',
                                        config=json.dumps(
                                            {'name': 'mean_label'}))
                ],
                model_names=['model_name1', 'model_name2'],
                output_names=['output_name1'],
                binarize=config.BinarizationOptions(class_ids=[0, 1]),
                aggregate=config.AggregationOptions(macro_average=True)))
        self.assertProtoEquals(
            metrics_specs[3],
            config.MetricsSpec(metrics=[
                config.MetricConfig(class_name='WeightedExampleCount',
                                    config=json.dumps(
                                        {'name': 'weighted_example_count'})),
            ],
                               model_names=['model_name1', 'model_name2'],
                               output_names=['output_name2']))
        self.assertProtoEquals(
            metrics_specs[4],
            config.MetricsSpec(
                metrics=[
                    config.MetricConfig(class_name='RootMeanSquaredError',
                                        config=json.dumps({
                                            'name': 'rmse',
                                            'dtype': 'float32'
                                        })),
                    config.MetricConfig(class_name='MeanPrediction',
                                        config=json.dumps(
                                            {'name': 'mean_prediction'}))
                ],
                model_names=['model_name1', 'model_name2'],
                output_names=['output_name2'],
                binarize=config.BinarizationOptions(class_ids=[0, 1]),
                aggregate=config.AggregationOptions(macro_average=True)))
Example #9
    def testToComputations(self):
        computations = metric_specs.to_computations(
            metric_specs.specs_from_metrics(
                {
                    'output_name': [
                        tf.keras.metrics.MeanSquaredError('mse'),
                        # Add a loss exactly same as metric
                        # (https://github.com/tensorflow/tfx/issues/1550)
                        tf.keras.losses.MeanSquaredError(name='loss'),
                        calibration.MeanLabel('mean_label')
                    ]
                },
                model_names=['model_name'],
                binarize=config.BinarizationOptions(
                    class_ids={'values': [0, 1]}),
                aggregate=config.AggregationOptions(macro_average=True,
                                                    class_weights={
                                                        0: 1.0,
                                                        1: 1.0
                                                    })),
            config.EvalConfig())

        keys = []
        for m in computations:
            for k in m.keys:
                if not k.name.startswith('_'):
                    keys.append(k)
        self.assertLen(keys, 11)
        self.assertIn(
            metric_types.MetricKey(name='example_count',
                                   model_name='model_name'), keys)
        self.assertIn(
            metric_types.MetricKey(name='weighted_example_count',
                                   model_name='model_name',
                                   output_name='output_name'), keys)
        self.assertIn(
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=0)),
            keys)
        self.assertIn(
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=1)),
            keys)
        self.assertIn(
            metric_types.MetricKey(name='mse',
                                   model_name='model_name',
                                   output_name='output_name'), keys)
        self.assertIn(
            metric_types.MetricKey(name='loss',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=0)),
            keys)
        self.assertIn(
            metric_types.MetricKey(name='loss',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=1)),
            keys)
        self.assertIn(
            metric_types.MetricKey(name='loss',
                                   model_name='model_name',
                                   output_name='output_name'), keys)
        self.assertIn(
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=0)),
            keys)
        self.assertIn(
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name',
                                   sub_key=metric_types.SubKey(class_id=1)),
            keys)
        self.assertIn(
            metric_types.MetricKey(name='mean_label',
                                   model_name='model_name',
                                   output_name='output_name'), keys)
    def testEvaluateWithKerasModel(self):
        input1 = tf.keras.layers.Input(shape=(1, ), name='input1')
        input2 = tf.keras.layers.Input(shape=(1, ), name='input2')
        inputs = [input1, input2]
        input_layer = tf.keras.layers.concatenate(inputs)
        output_layer = tf.keras.layers.Dense(1,
                                             activation=tf.nn.sigmoid,
                                             name='output')(input_layer)
        model = tf.keras.models.Model(inputs, output_layer)
        model.compile(optimizer=tf.keras.optimizers.Adam(lr=.001),
                      loss=tf.keras.losses.binary_crossentropy,
                      metrics=['accuracy'])

        features = {'input1': [[0.0], [1.0]], 'input2': [[1.0], [0.0]]}
        labels = [[1], [0]]
        example_weights = [1.0, 0.5]
        dataset = tf.data.Dataset.from_tensor_slices(
            (features, labels, example_weights))
        dataset = dataset.shuffle(buffer_size=1).repeat().batch(2)
        model.fit(dataset, steps_per_epoch=1)

        export_dir = self._getExportDir()
        model.save(export_dir, save_format='tf')

        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_key='label',
                                 example_weight_key='example_weight')
            ],
            slicing_specs=[config.SlicingSpec()],
            metrics_specs=metric_specs.specs_from_metrics(
                [calibration.MeanLabel('mean_label')]))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model]),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        examples = [
            self._makeExample(input1=0.0,
                              input2=1.0,
                              label=1.0,
                              example_weight=1.0,
                              extra_feature='non_model_feature'),
            self._makeExample(input1=1.0,
                              input2=0.0,
                              label=0.0,
                              example_weight=0.5,
                              extra_feature='non_model_feature'),
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count')
                    label_key = metric_types.MetricKey(name='mean_label')
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            example_count_key: 2,
                            weighted_example_count_key: (1.0 + 0.5),
                            label_key: (1.0 * 1.0 + 0.0 * 0.5) / (1.0 + 0.5),
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')
    def testEvaluateWithMultiOutputModel(self):
        temp_export_dir = self._getExportDir()
        _, export_dir = multi_head.simple_multi_head(None, temp_export_dir)

        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_keys={
                                     'chinese_head': 'chinese_label',
                                     'english_head': 'english_label',
                                     'other_head': 'other_label'
                                 },
                                 example_weight_keys={
                                     'chinese_head': 'age',
                                     'english_head': 'age',
                                     'other_head': 'age'
                                 })
            ],
            slicing_specs=[config.SlicingSpec()],
            metrics_specs=metric_specs.specs_from_metrics({
                'chinese_head': [calibration.MeanLabel('mean_label')],
                'english_head': [calibration.MeanLabel('mean_label')],
                'other_head': [calibration.MeanLabel('mean_label')],
            }))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model]),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        examples = [
            self._makeExample(age=1.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=1.0,
                              language='chinese',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=0.0),
            self._makeExample(age=2.0,
                              language='english',
                              english_label=1.0,
                              chinese_label=0.0,
                              other_label=0.0),
            self._makeExample(age=2.0,
                              language='other',
                              english_label=0.0,
                              chinese_label=1.0,
                              other_label=1.0),
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    chinese_weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count',
                        output_name='chinese_head')
                    chinese_label_key = metric_types.MetricKey(
                        name='mean_label', output_name='chinese_head')
                    english_weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count',
                        output_name='english_head')
                    english_label_key = metric_types.MetricKey(
                        name='mean_label', output_name='english_head')
                    other_weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count',
                        output_name='other_head')
                    other_label_key = metric_types.MetricKey(
                        name='mean_label', output_name='other_head')
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            example_count_key:
                            4,
                            chinese_label_key:
                            (0.0 + 1.0 + 2 * 0.0 + 2 * 1.0) /
                            (1.0 + 1.0 + 2.0 + 2.0),
                            chinese_weighted_example_count_key:
                            (1.0 + 1.0 + 2.0 + 2.0),
                            english_label_key:
                            (1.0 + 0.0 + 2 * 1.0 + 2 * 0.0) /
                            (1.0 + 1.0 + 2.0 + 2.0),
                            english_weighted_example_count_key:
                            (1.0 + 1.0 + 2.0 + 2.0),
                            other_label_key: (0.0 + 0.0 + 2 * 0.0 + 2 * 1.0) /
                            (1.0 + 1.0 + 2.0 + 2.0),
                            other_weighted_example_count_key:
                            (1.0 + 1.0 + 2.0 + 2.0)
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')
    def testEvaluateWithSlicing(self):
        temp_export_dir = self._getExportDir()
        _, export_dir = (fixed_prediction_estimator_extra_fields.
                         simple_fixed_prediction_estimator_extra_fields(
                             None, temp_export_dir))
        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_key='label',
                                 example_weight_key='fixed_float')
            ],
            slicing_specs=[
                config.SlicingSpec(),
                config.SlicingSpec(feature_keys=['fixed_string']),
            ],
            metrics_specs=metric_specs.specs_from_metrics([
                calibration.MeanLabel('mean_label'),
                calibration.MeanPrediction('mean_prediction')
            ]))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir)
        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            predict_extractor.PredictExtractor(
                eval_shared_model=eval_shared_model),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        # fixed_float used as example_weight key
        examples = [
            self._makeExample(prediction=0.2,
                              label=1.0,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string1'),
            self._makeExample(prediction=0.8,
                              label=0.0,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string1'),
            self._makeExample(prediction=0.5,
                              label=0.0,
                              fixed_int=2,
                              fixed_float=2.0,
                              fixed_string='fixed_string2')
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 3)
                    slices = {}
                    for slice_key, value in got:
                        slices[slice_key] = value
                    overall_slice = ()
                    fixed_string1_slice = (('fixed_string',
                                            b'fixed_string1'), )
                    fixed_string2_slice = (('fixed_string',
                                            b'fixed_string2'), )
                    self.assertCountEqual(list(slices.keys()), [
                        overall_slice, fixed_string1_slice, fixed_string2_slice
                    ])
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count')
                    label_key = metric_types.MetricKey(name='mean_label')
                    pred_key = metric_types.MetricKey(name='mean_prediction')
                    self.assertDictElementsAlmostEqual(
                        slices[overall_slice], {
                            example_count_key: 3,
                            weighted_example_count_key: 4.0,
                            label_key:
                            (1.0 + 0.0 + 2 * 0.0) / (1.0 + 1.0 + 2.0),
                            pred_key:
                            (0.2 + 0.8 + 2 * 0.5) / (1.0 + 1.0 + 2.0),
                        })
                    self.assertDictElementsAlmostEqual(
                        slices[fixed_string1_slice], {
                            example_count_key: 2,
                            weighted_example_count_key: 2.0,
                            label_key: (1.0 + 0.0) / (1.0 + 1.0),
                            pred_key: (0.2 + 0.8) / (1.0 + 1.0),
                        })
                    self.assertDictElementsAlmostEqual(
                        slices[fixed_string2_slice], {
                            example_count_key: 1,
                            weighted_example_count_key: 2.0,
                            label_key: (2 * 0.0) / 2.0,
                            pred_key: (2 * 0.5) / 2.0,
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')
    def testEvaluateWithMultiClassModel(self):
        n_classes = 3
        temp_export_dir = self._getExportDir()
        _, export_dir = dnn_classifier.simple_dnn_classifier(
            None, temp_export_dir, n_classes=n_classes)

        # Add example_count and weighted_example_count
        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_key='label',
                                 example_weight_key='age')
            ],
            slicing_specs=[config.SlicingSpec()],
            metrics_specs=metric_specs.specs_from_metrics(
                [calibration.MeanLabel('mean_label')],
                binarize=config.BinarizationOptions(
                    class_ids=range(n_classes))))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model]),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        examples = [
            self._makeExample(age=1.0, language='english', label=0),
            self._makeExample(age=2.0, language='chinese', label=1),
            self._makeExample(age=3.0, language='english', label=2),
            self._makeExample(age=4.0, language='chinese', label=1),
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count')
                    label_key_class_0 = metric_types.MetricKey(
                        name='mean_label',
                        sub_key=metric_types.SubKey(class_id=0))
                    label_key_class_1 = metric_types.MetricKey(
                        name='mean_label',
                        sub_key=metric_types.SubKey(class_id=1))
                    label_key_class_2 = metric_types.MetricKey(
                        name='mean_label',
                        sub_key=metric_types.SubKey(class_id=2))
                    self.assertEqual(got_slice_key, ())
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            example_count_key:
                            4,
                            weighted_example_count_key:
                            (1.0 + 2.0 + 3.0 + 4.0),
                            label_key_class_0:
                            (1 * 1.0 + 0 * 2.0 + 0 * 3.0 + 0 * 4.0) /
                            (1.0 + 2.0 + 3.0 + 4.0),
                            label_key_class_1:
                            (0 * 1.0 + 1 * 2.0 + 0 * 3.0 + 1 * 4.0) /
                            (1.0 + 2.0 + 3.0 + 4.0),
                            label_key_class_2:
                            (0 * 1.0 + 0 * 2.0 + 1 * 3.0 + 0 * 4.0) /
                            (1.0 + 2.0 + 3.0 + 4.0)
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')
    def testEvaluateWithBinaryClassificationModel(self):
        n_classes = 2
        temp_export_dir = self._getExportDir()
        _, export_dir = dnn_classifier.simple_dnn_classifier(
            None, temp_export_dir, n_classes=n_classes)

        # Add mean_label, example_count, weighted_example_count, calibration_plot
        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(location=export_dir,
                                 label_key='label',
                                 example_weight_key='age')
            ],
            slicing_specs=[config.SlicingSpec()],
            metrics_specs=metric_specs.specs_from_metrics([
                calibration.MeanLabel('mean_label'),
                calibration_plot.CalibrationPlot(name='calibration_plot',
                                                 num_buckets=10)
            ]))
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model]),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config,
                eval_shared_models=[eval_shared_model])
        ]

        examples = [
            self._makeExample(age=1.0, language='english', label=0.0),
            self._makeExample(age=2.0, language='chinese', label=1.0),
            self._makeExample(age=3.0, language='chinese', label=0.0),
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics_and_plots = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_metrics = got[0]
                    self.assertEqual(got_slice_key, ())
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count')
                    label_key = metric_types.MetricKey(name='mean_label')
                    self.assertDictElementsAlmostEqual(
                        got_metrics, {
                            example_count_key:
                            3,
                            weighted_example_count_key: (1.0 + 2.0 + 3.0),
                            label_key:
                            (0 * 1.0 + 1 * 2.0 + 0 * 3.0) / (1.0 + 2.0 + 3.0),
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            def check_plots(got):
                try:
                    self.assertLen(got, 1)
                    got_slice_key, got_plots = got[0]
                    self.assertEqual(got_slice_key, ())
                    plot_key = metric_types.PlotKey('calibration_plot')
                    self.assertIn(plot_key, got_plots)
                    # 10 buckets + 2 for edge cases
                    self.assertLen(got_plots[plot_key].buckets, 12)

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics_and_plots[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')
            util.assert_that(metrics_and_plots[constants.PLOTS_KEY],
                             check_plots,
                             label='plots')
    def testEvaluateWithConfidenceIntervals(self):
        # NOTE: This test does not actually test that confidence intervals are
        #   accurate; it only tests that the proto output by the test is well
        #   formed. This test would pass if the confidence interval
        #   implementation did nothing at all except compute the unsampled
        #   value.
        temp_export_dir = self._getExportDir()
        _, export_dir = (fixed_prediction_estimator_extra_fields.
                         simple_fixed_prediction_estimator_extra_fields(
                             None, temp_export_dir))
        options = config.Options()
        options.compute_confidence_intervals.value = True
        eval_config = config.EvalConfig(
            model_specs=[
                config.ModelSpec(label_key='label',
                                 example_weight_key='fixed_float')
            ],
            slicing_specs=[
                config.SlicingSpec(),
                config.SlicingSpec(feature_keys=['fixed_string']),
            ],
            metrics_specs=metric_specs.specs_from_metrics([
                calibration.MeanLabel('mean_label'),
                calibration.MeanPrediction('mean_prediction')
            ]),
            options=options)
        eval_shared_model = self.createTestEvalSharedModel(
            eval_saved_model_path=export_dir, tags=[tf.saved_model.SERVING])

        slice_spec = [
            slicer.SingleSliceSpec(spec=s) for s in eval_config.slicing_specs
        ]
        extractors = [
            input_extractor.InputExtractor(eval_config=eval_config),
            predict_extractor_v2.PredictExtractor(
                eval_config=eval_config, eval_shared_model=eval_shared_model),
            slice_key_extractor.SliceKeyExtractor(slice_spec=slice_spec)
        ]
        evaluators = [
            metrics_and_plots_evaluator_v2.MetricsAndPlotsEvaluator(
                eval_config=eval_config, eval_shared_model=eval_shared_model)
        ]

        # fixed_float used as example_weight key
        examples = [
            self._makeExample(prediction=0.2,
                              label=1.0,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string1'),
            self._makeExample(prediction=0.8,
                              label=0.0,
                              fixed_int=1,
                              fixed_float=1.0,
                              fixed_string='fixed_string1'),
            self._makeExample(prediction=0.5,
                              label=0.0,
                              fixed_int=2,
                              fixed_float=2.0,
                              fixed_string='fixed_string2')
        ]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            metrics = (
                pipeline
                | 'Create' >> beam.Create(
                    [e.SerializeToString() for e in examples])
                | 'InputsToExtracts' >> model_eval_lib.InputsToExtracts()
                | 'ExtractAndEvaluate' >> model_eval_lib.ExtractAndEvaluate(
                    extractors=extractors, evaluators=evaluators))

            # pylint: enable=no-value-for-parameter

            def check_metrics(got):
                try:
                    self.assertLen(got, 3)
                    slices = {}
                    for slice_key, value in got:
                        slices[slice_key] = value
                    overall_slice = ()
                    fixed_string1_slice = (('fixed_string',
                                            b'fixed_string1'), )
                    fixed_string2_slice = (('fixed_string',
                                            b'fixed_string2'), )
                    self.assertCountEqual(list(slices.keys()), [
                        overall_slice, fixed_string1_slice, fixed_string2_slice
                    ])
                    example_count_key = metric_types.MetricKey(
                        name='example_count')
                    weighted_example_count_key = metric_types.MetricKey(
                        name='weighted_example_count')
                    label_key = metric_types.MetricKey(name='mean_label')
                    pred_key = metric_types.MetricKey(name='mean_prediction')
                    self.assertDictElementsWithTDistributionAlmostEqual(
                        slices[overall_slice], {
                            example_count_key: 3,
                            weighted_example_count_key: 4.0,
                            label_key:
                            (1.0 + 0.0 + 2 * 0.0) / (1.0 + 1.0 + 2.0),
                            pred_key:
                            (0.2 + 0.8 + 2 * 0.5) / (1.0 + 1.0 + 2.0),
                        })
                    self.assertDictElementsWithTDistributionAlmostEqual(
                        slices[fixed_string1_slice], {
                            example_count_key: 2,
                            weighted_example_count_key: 2.0,
                            label_key: (1.0 + 0.0) / (1.0 + 1.0),
                            pred_key: (0.2 + 0.8) / (1.0 + 1.0),
                        })
                    self.assertDictElementsWithTDistributionAlmostEqual(
                        slices[fixed_string2_slice], {
                            example_count_key: 1,
                            weighted_example_count_key: 2.0,
                            label_key: (2 * 0.0) / 2.0,
                            pred_key: (2 * 0.5) / 2.0,
                        })

                except AssertionError as err:
                    raise util.BeamAssertException(err)

            util.assert_that(metrics[constants.METRICS_KEY],
                             check_metrics,
                             label='metrics')