def testMetricComputedBeamCounter(self):
    pipeline = beam.Pipeline()
    auc = post_export_metrics.auc()
    _ = pipeline | counter_util.IncrementMetricsCallbacksCounters([auc])

    # Run the pipeline exactly once and wait for completion before querying
    # metrics. (Running inside a `with beam.Pipeline()` block and then calling
    # pipeline.run() again afterwards would execute the pipeline twice.)
    result = pipeline.run()
    result.wait_until_finish()
    metric_filter = beam.metrics.metric.MetricsFilter().with_namespace(
        constants.METRICS_NAMESPACE).with_name('metric_computed_auc_v1')
    actual_metrics_count = result.metrics().query(
        filter=metric_filter)['counters'][0].committed

    self.assertEqual(actual_metrics_count, 1)
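The test above uses the standard Beam metrics pattern: run a pipeline, then query its metric results filtered by namespace and name. A minimal self-contained sketch of the same pattern follows; the 'my_namespace' and 'elements_seen' names are illustrative placeholders, not TFMA identifiers.

import apache_beam as beam
from apache_beam.metrics import Metrics
from apache_beam.metrics.metric import MetricsFilter


class CountElementsFn(beam.DoFn):
  """Increments a user counter once per processed element."""

  def __init__(self):
    self._counter = Metrics.counter('my_namespace', 'elements_seen')

  def process(self, element):
    self._counter.inc()
    yield element


pipeline = beam.Pipeline()
_ = pipeline | beam.Create([1, 2, 3]) | beam.ParDo(CountElementsFn())
result = pipeline.run()
result.wait_until_finish()

# Query the committed counter value for the finished run.
metric_filter = MetricsFilter().with_namespace('my_namespace').with_name(
    'elements_seen')
count = result.metrics().query(filter=metric_filter)['counters'][0].committed
assert count == 3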
Example #2
def ComputeMetricsAndPlots(  # pylint: disable=invalid-name
    extracts: beam.pvalue.PCollection,
    eval_shared_model: types.EvalSharedModel,
    desired_batch_size: Optional[int] = None,
    compute_confidence_intervals: Optional[bool] = False,
    random_seed_for_testing: Optional[int] = None
) -> Tuple[beam.pvalue.DoOutputsTuple, beam.pvalue.PCollection]:
    """Computes metrics and plots using the EvalSavedModel.

  Args:
    extracts: PCollection of Extracts. The extracts MUST contain a
      FeaturesPredictionsLabels extract keyed by
      tfma.FEATURE_PREDICTIONS_LABELS_KEY and a list of SliceKeyType extracts
      keyed by tfma.SLICE_KEY_TYPES_KEY. Typically these will be added by
      calling the default_extractors function.
    eval_shared_model: Shared model parameters for EvalSavedModel including any
      additional metrics (see EvalSharedModel for more information on how to
      configure additional metrics).
    desired_batch_size: Optional batch size for batching in Aggregate.
    compute_confidence_intervals: Set to True to run metrics analysis over
      multiple bootstrap samples and compute uncertainty intervals.
    random_seed_for_testing: Provide for deterministic tests only.

  Returns:
    A tuple of (1) a DoOutputsTuple whose main output is a PCollection of
    (slice key, metrics) and whose tagged output is a PCollection of
    (slice key, plot metrics), and (2) a PCollection of
    (slice key, example count).
  """
  # pylint: disable=no-value-for-parameter

  slices = (
      extracts
      # Downstream computation only cares about FPLs, so we prune before
      # fanout. Note that fanout itself will prune the slice keys.
      # TODO(b/130032676, b/111353165): Prune FPLs to contain only the
      # necessary set for the calculation of post_export_metrics if possible.
      | 'PruneExtracts' >> extractor.Filter(include=[
          constants.FEATURES_PREDICTIONS_LABELS_KEY,
          constants.SLICE_KEY_TYPES_KEY,
          constants.INPUT_KEY,
      ])
      # Input: one example at a time, with slice keys in extracts.
      # Output: one fpl example per slice key (notice that the example turns
      #         into n logical examples, references to which are replicated
      #         once per applicable slice key).
      | 'FanoutSlices' >> slicer.FanoutSlices())

  slices_count = (slices
                  | 'ExtractSliceKeys' >> beam.Keys()
                  | 'CountPerSliceKey' >> beam.combiners.Count.PerElement())

  _ = (
      extracts.pipeline
      | 'IncrementMetricsCallbacksCounters' >>
      counter_util.IncrementMetricsCallbacksCounters(
          eval_shared_model.add_metrics_callbacks))

  _ = (
      slices_count
      | 'IncrementSliceSpecCounters' >>
      counter_util.IncrementSliceSpecCounters())

  aggregated_metrics = (
      slices
      # Metrics are computed per slice key.
      # Output: Multi-outputs, a dict of slice key to computed metrics, and
      # plots if applicable.
      | 'ComputePerSliceMetrics' >>
      poisson_bootstrap.ComputeWithConfidenceIntervals(
          aggregate.ComputePerSliceMetrics,
          num_bootstrap_samples=(
              poisson_bootstrap.DEFAULT_NUM_BOOTSTRAP_SAMPLES
              if compute_confidence_intervals else 1),
          random_seed_for_testing=random_seed_for_testing,
          eval_shared_model=eval_shared_model,
          desired_batch_size=desired_batch_size)
      | 'SeparateMetricsAndPlots' >> beam.ParDo(
          _SeparateMetricsAndPlotsFn()).with_outputs(
              _SeparateMetricsAndPlotsFn.OUTPUT_TAG_PLOTS,
              main=_SeparateMetricsAndPlotsFn.OUTPUT_TAG_METRICS))

  return (aggregated_metrics, slices_count)
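For orientation, here is a hedged wiring sketch showing how extracts carrying the required FPL and slice-key entries would typically be produced before this function is applied. It assumes the public tfma helpers (tfma.default_eval_shared_model, tfma.InputsToExtracts, tfma.default_extractors) and placeholder paths; applying each default extractor in a loop stands in for the higher-level TFMA helpers that normally do this, and most users would go through TFMA's top-level APIs rather than calling ComputeMetricsAndPlots directly.

# A minimal usage sketch under the assumptions stated above; the paths are
# placeholders, and ComputeMetricsAndPlots is called with the plain-function
# signature shown in this example.
import apache_beam as beam
import tensorflow_model_analysis as tfma

eval_shared_model = tfma.default_eval_shared_model(
    eval_saved_model_path='/path/to/eval_saved_model')

with beam.Pipeline() as pipeline:
  extracts = (
      pipeline
      | 'ReadExamples' >> beam.io.ReadFromTFRecord('/path/to/examples')
      | 'InputsToExtracts' >> tfma.InputsToExtracts())
  # default_extractors adds the FeaturesPredictionsLabels and slice-key
  # extracts that ComputeMetricsAndPlots requires.
  for ex in tfma.default_extractors(eval_shared_model=eval_shared_model):
    extracts = extracts | ex.stage_name >> ex.ptransform
  metrics_and_plots, slices_count = ComputeMetricsAndPlots(
      extracts, eval_shared_model=eval_shared_model)

Per the Returns section above, metrics_and_plots is a DoOutputsTuple: (slice key, metrics) pairs on the main output and (slice key, plot metrics) pairs on the plots tag, while slices_count holds per-slice example counts.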