# Shared imports for the test methods below. Module paths are assumed from
# the TensorFlow Model Analysis codebase these tests exercise; the methods
# belong to test classes whose ``class`` lines are omitted in this excerpt.
import apache_beam as beam
from apache_beam.testing import util
import numpy as np
from numpy import testing
from tensorflow_model_analysis import types
from tensorflow_model_analysis.evaluators import confidence_intervals_util
from tensorflow_model_analysis.evaluators import jackknife
from tensorflow_model_analysis.evaluators import poisson_bootstrap
from tensorflow_model_analysis.metrics import binary_confusion_matrices
from tensorflow_model_analysis.metrics import metric_types

# _FULL_SAMPLE_ID and _ValidateSampleCombineFn are assumed to be defined by
# the enclosing test module as aliases/helpers over confidence_intervals_util
# internals.


  def test_bootstrap_sample_combine_fn(self):
    metric_key = metric_types.MetricKey(name='metric')
    samples = [
        confidence_intervals_util.SampleMetrics(
            sample_id=0, metrics={metric_key: 0}),
        confidence_intervals_util.SampleMetrics(
            sample_id=1, metrics={metric_key: 7}),
        confidence_intervals_util.SampleMetrics(
            sample_id=poisson_bootstrap._FULL_SAMPLE_ID,
            metrics={metric_key: 4})
    ]

    with beam.Pipeline() as pipeline:
      result = (
          pipeline
          | 'Create' >> beam.Create(samples, reshuffle=False)
          | 'CombineSamples' >> beam.CombineGlobally(
              poisson_bootstrap._BootstrapSampleCombineFn(
                  num_bootstrap_samples=2)))

      def check_result(got_pcoll):
        self.assertLen(got_pcoll, 1)
        metrics = got_pcoll[0]

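        # Samples are [0, 7]: mean = 3.5, and the ddof=1 standard deviation
        # is sqrt((3.5**2 + 3.5**2) / 1) ≈ 4.95; degrees of freedom = 2 - 1.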
        self.assertIn(metric_key, metrics)
        self.assertAlmostEqual(metrics[metric_key].sample_mean, 3.5, delta=0.1)
        self.assertAlmostEqual(
            metrics[metric_key].sample_standard_deviation, 4.94, delta=0.1)
        self.assertEqual(metrics[metric_key].sample_degrees_of_freedom, 1)
        self.assertEqual(metrics[metric_key].unsampled_value, 4.0)

      util.assert_that(result, check_result)

  def test_bootstrap_sample_combine_fn_sample_is_nan(self):
    metric_key = metric_types.MetricKey('metric')
    # One of the bootstrap samples below is NaN, which should propagate into
    # the combined statistics.
    samples = [
        # unsampled value
        confidence_intervals_util.SampleMetrics(
            sample_id=poisson_bootstrap._FULL_SAMPLE_ID,
            metrics={metric_key: 2}),
        confidence_intervals_util.SampleMetrics(
            sample_id=0, metrics={metric_key: 2}),
        confidence_intervals_util.SampleMetrics(
            sample_id=1, metrics={metric_key: float('nan')}),
    ]

    with beam.Pipeline() as pipeline:
      result = (
          pipeline
          | 'Create' >> beam.Create(samples, reshuffle=False)
          | 'CombineSamplesPerKey' >> beam.CombineGlobally(
              poisson_bootstrap._BootstrapSampleCombineFn(
                  num_bootstrap_samples=2)))

      def check_result(got_pcoll):
        self.assertLen(got_pcoll, 1)
        metrics = got_pcoll[0]

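        # The NaN sample poisons the combined mean and standard deviation;
        # the unsampled point estimate from _FULL_SAMPLE_ID stays finite.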
        self.assertIn(metric_key, metrics)
        self.assertTrue(np.isnan(metrics[metric_key].sample_mean))
        self.assertTrue(np.isnan(metrics[metric_key].sample_standard_deviation))
        self.assertEqual(metrics[metric_key].sample_degrees_of_freedom, 1)
        self.assertEqual(metrics[metric_key].unsampled_value, 2.0)

      util.assert_that(result, check_result)
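
  # With no metrics in any sample, the combiner should still emit one
  # accumulator per slice that records the number of samples seen.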
  def test_sample_combine_fn_no_input(self):
    slice_key = (('slice_feature', 1),)
    samples = [
        (slice_key,
         confidence_intervals_util.SampleMetrics(
             sample_id=_FULL_SAMPLE_ID, metrics={})),
        (slice_key,
         confidence_intervals_util.SampleMetrics(sample_id=0, metrics={})),
        (slice_key,
         confidence_intervals_util.SampleMetrics(sample_id=1, metrics={})),
    ]

    with beam.Pipeline() as pipeline:
      result = (
          pipeline
          | 'Create' >> beam.Create(samples)
          | 'CombineSamplesPerKey' >> beam.CombinePerKey(
              _ValidateSampleCombineFn(
                  num_samples=2, full_sample_id=_FULL_SAMPLE_ID)))

      def check_result(got_pcoll):
        self.assertLen(got_pcoll, 1)
        accumulators_by_slice = dict(got_pcoll)
        self.assertIn(slice_key, accumulators_by_slice)
        accumulator = accumulators_by_slice[slice_key]
        self.assertEqual(2, accumulator.num_samples)
        self.assertIsInstance(accumulator.point_estimates, dict)
        self.assertIsInstance(accumulator.metric_samples, dict)

      util.assert_that(result, check_result)
def _add_sample_id(  # pylint: disable=invalid-name
    slice_key,
    metrics_dict: metric_types.MetricsDict,
    sample_id: int = 0):
  # sample_id has a default value in order to satisfy the signature expected
  # by beam.MapTuple.
  return slice_key, confidence_intervals_util.SampleMetrics(
      metrics=metrics_dict, sample_id=sample_id)
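
# Example usage sketch (assumes `sliced_metrics` is a PCollection of
# (slice_key, metrics_dict) pairs; the names are illustrative only):
#
#   sliced_samples = sliced_metrics | beam.MapTuple(
#       _add_sample_id, sample_id=sample_id)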

  def test_bootstrap_sample_combine_fn_numpy_overflow(self):
    sample_values = np.random.RandomState(seed=0).randint(0, 1e10, 20)
    metric_key = metric_types.MetricKey('metric')
    samples = [
        confidence_intervals_util.SampleMetrics(
            sample_id=poisson_bootstrap._FULL_SAMPLE_ID,
            metrics={
                metric_key: 1,
            })
    ]
    for sample_id, value in enumerate(sample_values):
      samples.append(
          confidence_intervals_util.SampleMetrics(
              sample_id=sample_id, metrics={
                  metric_key: value,
              }))
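    # Sample values up to 1e10 exceed int32 range, and summing their squares
    # (~1e20) would overflow even int64; the expected values below assume the
    # combiner computes the variance in float64.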
    with beam.Pipeline() as pipeline:
      result = (
          pipeline
          | 'Create' >> beam.Create(samples, reshuffle=False)
          | 'CombineSamples' >> beam.CombineGlobally(
              poisson_bootstrap._BootstrapSampleCombineFn(
                  num_bootstrap_samples=20)))

      def check_result(got_pcoll):
        expected_pcoll = [
            {
                metric_key:
                    types.ValueWithTDistribution(
                        sample_mean=5293977041.15,
                        sample_standard_deviation=3023624729.537024,
                        sample_degrees_of_freedom=19,
                        unsampled_value=1),
            },
        ]
        self.assertCountEqual(expected_pcoll, got_pcoll)

      util.assert_that(result, check_result)
  def test_jackknife_sample_combine_fn(self):
    x_key = metric_types.MetricKey('x')
    y_key = metric_types.MetricKey('y')
    cm_key = metric_types.MetricKey('confusion_matrix')
    cm_metric = binary_confusion_matrices.Matrices(
        thresholds=[0.5], tp=[0], fp=[1], tn=[2], fn=[3])
    slice_key1 = (('slice_feature', 1),)
    slice_key2 = (('slice_feature', 2),)
    samples = [
        # point estimate for slice 1
        (slice_key1,
         confidence_intervals_util.SampleMetrics(
             sample_id=jackknife._FULL_SAMPLE_ID,
             metrics={
                 x_key: 1.6,
                 y_key: 16,
                 cm_key: cm_metric,
             })),
        # sample values 1 of 2 for slice 1
        (slice_key1,
         confidence_intervals_util.SampleMetrics(
             sample_id=0,
             metrics={
                 x_key: 1,
                 y_key: 10,
                 cm_key: cm_metric - 1,
             })),
        # sample values 2 of 2 for slice 1
        (slice_key1,
         confidence_intervals_util.SampleMetrics(
             sample_id=1,
             metrics={
                 x_key: 2,
                 y_key: 20,
                 cm_key: cm_metric + 1,
             })),
        # point estimate for slice 2
        (slice_key2,
         confidence_intervals_util.SampleMetrics(
             sample_id=jackknife._FULL_SAMPLE_ID,
             metrics={
                 x_key: 3.3,
                 y_key: 33,
                 cm_key: cm_metric,
             })),
        # sample values 1 of 2 for slice 2
        (slice_key2,
         confidence_intervals_util.SampleMetrics(
             sample_id=0,
             metrics={
                 x_key: 2,
                 y_key: 20,
                 cm_key: cm_metric - 10,
             })),
        # sample values 2 of 2 for slice 2
        (slice_key2,
         confidence_intervals_util.SampleMetrics(
             sample_id=1,
             metrics={
                 x_key: 4,
                 y_key: 40,
                 cm_key: cm_metric + 10,
             })),
    ]

    with beam.Pipeline() as pipeline:
      result = (
          pipeline
          | 'Create' >> beam.Create(samples, reshuffle=False)
          | 'CombineJackknifeSamplesPerKey' >> beam.CombinePerKey(
              jackknife._JackknifeSampleCombineFn(num_jackknife_samples=2)))

      # WARNING: Do not change this test without carefully considering the
      # impact on clients due to changed CI bounds. The current implementation
      # follows the jackknife cookie-bucket method described in:
      # go/rasta-confidence-intervals
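      # The expected standard deviations are consistent with the delete-one
      # jackknife variance formula (n - 1) / n * sum((x_i - mean)**2): e.g.
      # for x_key on slice 1 the samples are [1, 2], so the standard
      # deviation is sqrt(1/2 * ((1 - 1.5)**2 + (2 - 1.5)**2)) = 0.5.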
      def check_result(got_pcoll):
        expected_pcoll = [
            (slice_key1, {
                x_key:
                    types.ValueWithTDistribution(
                        sample_mean=1.5,
                        sample_standard_deviation=0.5,
                        sample_degrees_of_freedom=1,
                        unsampled_value=1.6),
                y_key:
                    types.ValueWithTDistribution(
                        sample_mean=15.,
                        sample_standard_deviation=5,
                        sample_degrees_of_freedom=1,
                        unsampled_value=16),
                cm_key:
                    types.ValueWithTDistribution(
                        sample_mean=cm_metric,
                        sample_standard_deviation=(
                            binary_confusion_matrices.Matrices(
                                thresholds=[0.5], tp=[1], fp=[1], tn=[1],
                                fn=[1])),
                        sample_degrees_of_freedom=1,
                        unsampled_value=cm_metric),
            }),
            (slice_key2, {
                x_key:
                    types.ValueWithTDistribution(
                        sample_mean=3.,
                        sample_standard_deviation=1,
                        sample_degrees_of_freedom=1,
                        unsampled_value=3.3),
                y_key:
                    types.ValueWithTDistribution(
                        sample_mean=30.,
                        sample_standard_deviation=10,
                        sample_degrees_of_freedom=1,
                        unsampled_value=33),
                cm_key:
                    types.ValueWithTDistribution(
                        sample_mean=cm_metric,
                        sample_standard_deviation=(
                            binary_confusion_matrices.Matrices(
                                thresholds=[0.5], tp=[10], fp=[10], tn=[10],
                                fn=[10])),
                        sample_degrees_of_freedom=1,
                        unsampled_value=cm_metric),
            }),
        ]
        self.assertCountEqual(expected_pcoll, got_pcoll)

      util.assert_that(result, check_result)
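
  # Per-slice variant: CombinePerKey groups SampleMetrics by slice key before
  # combining, and skip_ci_metric_keys opts individual metrics out of CI
  # computation.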
  def test_bootstrap_sample_combine_fn_per_slice(self):
    x_key = metric_types.MetricKey('x')
    y_key = metric_types.MetricKey('y')
    cm_key = metric_types.MetricKey('confusion_matrix')
    cm_metric = binary_confusion_matrices.Matrices(
        thresholds=[0.5], tp=[0], fp=[1], tn=[2], fn=[3])
    skipped_metric_key = metric_types.MetricKey('skipped_metric')
    slice_key1 = (('slice_feature', 1),)
    slice_key2 = (('slice_feature', 2),)
    samples = [
        # unsampled value for slice 1
        (slice_key1,
         confidence_intervals_util.SampleMetrics(
             sample_id=poisson_bootstrap._FULL_SAMPLE_ID,
             metrics={
                 x_key: 1.6,
                 y_key: 16,
                 cm_key: cm_metric,
                 skipped_metric_key: 100,
             })),
        # sample values 1 of 2 for slice 1
        (slice_key1,
         confidence_intervals_util.SampleMetrics(
             sample_id=0,
             metrics={
                 x_key: 1,
                 y_key: 10,
                 cm_key: cm_metric,
                 skipped_metric_key: 45,
             })),
        # sample values 2 of 2 for slice 1
        (slice_key1,
         confidence_intervals_util.SampleMetrics(
             sample_id=1,
             metrics={
                 x_key: 2,
                 y_key: 20,
                 cm_key: cm_metric,
                 skipped_metric_key: 55,
             })),
        # unsampled value for slice 2
        (slice_key2,
         confidence_intervals_util.SampleMetrics(
             sample_id=poisson_bootstrap._FULL_SAMPLE_ID,
             metrics={
                 x_key: 3.3,
                 y_key: 33,
                 cm_key: cm_metric,
                 skipped_metric_key: 1000,
             })),
        # sample values 1 of 2 for slice 2
        (slice_key2,
         confidence_intervals_util.SampleMetrics(
             sample_id=0,
             metrics={
                 x_key: 2,
                 y_key: 20,
                 cm_key: cm_metric,
                 skipped_metric_key: 450,
             })),
        # sample values 2 of 2 for slice 2
        (slice_key2,
         confidence_intervals_util.SampleMetrics(
             sample_id=1,
             metrics={
                 x_key: 4,
                 y_key: 40,
                 cm_key: cm_metric,
                 skipped_metric_key: 550,
             })),
    ]

    with beam.Pipeline() as pipeline:
      result = (
          pipeline
          | 'Create' >> beam.Create(samples, reshuffle=False)
          | 'CombineSamplesPerKey' >> beam.CombinePerKey(
              poisson_bootstrap._BootstrapSampleCombineFn(
                  num_bootstrap_samples=2,
                  skip_ci_metric_keys=[skipped_metric_key])))

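      # Metrics listed in skip_ci_metric_keys are passed through as raw point
      # estimates (100 and 1000 below) rather than being wrapped in
      # ValueWithTDistribution.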
      def check_result(got_pcoll):
        expected_pcoll = [
            (
                slice_key1,
                {
                    x_key:
                        types.ValueWithTDistribution(
                            sample_mean=1.5,
                            # np.std([1, 2], ddof=1) = sqrt(0.5) ≈ 0.707
                            sample_standard_deviation=np.std([1, 2], ddof=1),
                            sample_degrees_of_freedom=1,
                            unsampled_value=1.6),
                    y_key:
                        types.ValueWithTDistribution(
                            sample_mean=15.,
                            # np.std([10, 20], ddof=1) = sqrt(50) ≈ 7.071
                            sample_standard_deviation=np.std([10, 20], ddof=1),
                            sample_degrees_of_freedom=1,
                            unsampled_value=16),
                    cm_key:
                        types.ValueWithTDistribution(
                            sample_mean=cm_metric,
                            sample_standard_deviation=cm_metric * 0,
                            sample_degrees_of_freedom=1,
                            unsampled_value=cm_metric),
                    skipped_metric_key:
                        100,
                }),
            (
                slice_key2,
                {
                    x_key:
                        types.ValueWithTDistribution(
                            sample_mean=3.,
                            # np.std([2, 4], ddof=1) = sqrt(2) ≈ 1.414
                            sample_standard_deviation=np.std([2, 4], ddof=1),
                            sample_degrees_of_freedom=1,
                            unsampled_value=3.3),
                    y_key:
                        types.ValueWithTDistribution(
                            sample_mean=30.,
                            # np.std([20, 40], ddof=1) = sqrt(200) ≈ 14.142
                            sample_standard_deviation=np.std([20, 40], ddof=1),
                            sample_degrees_of_freedom=1,
                            unsampled_value=33),
                    cm_key:
                        types.ValueWithTDistribution(
                            sample_mean=cm_metric,
                            sample_standard_deviation=cm_metric * 0,
                            sample_degrees_of_freedom=1,
                            unsampled_value=cm_metric),
                    skipped_metric_key:
                        1000,
                }),
        ]
        self.assertCountEqual(expected_pcoll, got_pcoll)

      util.assert_that(result, check_result)
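
  # The base sample combiner: numeric and array-valued metrics collect
  # per-sample values, non-numeric metrics pass through as point estimates
  # only, and missing samples surface as __ERROR__ entries and counters.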
  def test_sample_combine_fn(self):
    metric_key = metric_types.MetricKey('metric')
    array_metric_key = metric_types.MetricKey('array_metric')
    missing_sample_metric_key = metric_types.MetricKey('missing_metric')
    non_numeric_metric_key = metric_types.MetricKey('non_numeric_metric')
    non_numeric_array_metric_key = metric_types.MetricKey('non_numeric_array')
    skipped_metric_key = metric_types.MetricKey('skipped_metric')
    slice_key1 = (('slice_feature', 1),)
    slice_key2 = (('slice_feature', 2),)
    # Samples cover numeric, array, non-numeric, missing, and skipped metric
    # values; both the accumulator contents and counters are verified below.
    samples = [
        # unsampled value for slice 1
        (slice_key1,
         confidence_intervals_util.SampleMetrics(
             sample_id=_FULL_SAMPLE_ID,
             metrics={
                 metric_key: 2.1,
                 array_metric_key: np.array([1, 2]),
                 missing_sample_metric_key: 3,
                 non_numeric_metric_key: 'a',
                 non_numeric_array_metric_key: np.array(['a', 'aaa']),
                 skipped_metric_key: 16
             })),
        # sample values for slice 1
        (slice_key1,
         confidence_intervals_util.SampleMetrics(
             sample_id=0,
             metrics={
                 metric_key: 1,
                 array_metric_key: np.array([2, 3]),
                 missing_sample_metric_key: 2,
                 non_numeric_metric_key: 'b',
                 non_numeric_array_metric_key: np.array(['a', 'aaa']),
                 skipped_metric_key: 7
             })),
        # sample values for slice 1 missing missing_sample_metric_key
        (slice_key1,
         confidence_intervals_util.SampleMetrics(
             sample_id=1,
             metrics={
                 metric_key: 2,
                 array_metric_key: np.array([0, 1]),
                 non_numeric_metric_key: 'c',
                 non_numeric_array_metric_key: np.array(['a', 'aaa']),
                 skipped_metric_key: 8
             })),
        # unsampled value for slice 2
        (slice_key2,
         confidence_intervals_util.SampleMetrics(
             sample_id=_FULL_SAMPLE_ID,
             metrics={
                 metric_key: 6.3,
                 array_metric_key: np.array([10, 20]),
                 missing_sample_metric_key: 6,
                 non_numeric_metric_key: 'd',
                 non_numeric_array_metric_key: np.array(['a', 'aaa']),
                 skipped_metric_key: 10000
             })),
        # Only 1 sample value (missing sample ID 1) for slice 2
        (slice_key2,
         confidence_intervals_util.SampleMetrics(
             sample_id=0,
             metrics={
                 metric_key: 3,
                 array_metric_key: np.array([20, 30]),
                 missing_sample_metric_key: 12,
                 non_numeric_metric_key: 'd',
                 non_numeric_array_metric_key: np.array(['a', 'aaa']),
                 skipped_metric_key: 5000
             })),
    ]

    with beam.Pipeline() as pipeline:
      result = (
          pipeline
          | 'Create' >> beam.Create(samples, reshuffle=False)
          | 'CombineSamplesPerKey' >> beam.CombinePerKey(
              _ValidateSampleCombineFn(
                  num_samples=2,
                  full_sample_id=_FULL_SAMPLE_ID,
                  skip_ci_metric_keys=[skipped_metric_key])))

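      # _ValidateSampleCombineFn is assumed to be a test-only combiner that
      # returns its accumulator, letting the test inspect point_estimates and
      # metric_samples directly.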
      def check_result(got_pcoll):
        self.assertLen(got_pcoll, 2)
        accumulators_by_slice = dict(got_pcoll)

        self.assertIn(slice_key1, accumulators_by_slice)
        slice1_accumulator = accumulators_by_slice[slice_key1]
        # check unsampled value
        self.assertIn(metric_key, slice1_accumulator.point_estimates)
        self.assertEqual(2.1, slice1_accumulator.point_estimates[metric_key])
        # check numeric case sample_values
        self.assertIn(metric_key, slice1_accumulator.metric_samples)
        self.assertEqual([1, 2], slice1_accumulator.metric_samples[metric_key])
        # check numeric array in sample_values
        self.assertIn(array_metric_key, slice1_accumulator.metric_samples)
        array_metric_samples = (
            slice1_accumulator.metric_samples[array_metric_key])
        self.assertLen(array_metric_samples, 2)
        testing.assert_array_equal(np.array([2, 3]), array_metric_samples[0])
        testing.assert_array_equal(np.array([0, 1]), array_metric_samples[1])
        # check that non-numeric metric sample_values are not present
        self.assertIn(non_numeric_metric_key,
                      slice1_accumulator.point_estimates)
        self.assertNotIn(non_numeric_metric_key,
                         slice1_accumulator.metric_samples)
        self.assertIn(non_numeric_array_metric_key,
                      slice1_accumulator.point_estimates)
        self.assertNotIn(non_numeric_array_metric_key,
                         slice1_accumulator.metric_samples)
        # check that a metric missing from some samples generates an error
        error_key = metric_types.MetricKey('__ERROR__')
        self.assertIn(error_key, slice1_accumulator.point_estimates)
        self.assertRegex(slice1_accumulator.point_estimates[error_key],
                         'CI not computed for.*missing_metric.*')
        # check that skipped metrics have no samples
        self.assertNotIn(skipped_metric_key, slice1_accumulator.metric_samples)

        self.assertIn(slice_key2, accumulators_by_slice)
        slice2_accumulator = accumulators_by_slice[slice_key2]
        # check unsampled value
        self.assertIn(metric_key, slice2_accumulator.point_estimates)
        self.assertEqual(6.3, slice2_accumulator.point_estimates[metric_key])
        # check that an entirely missing sample generates an error
        self.assertIn(error_key, slice2_accumulator.point_estimates)
        self.assertRegex(slice2_accumulator.point_estimates[error_key],
                         'CI not computed because only 1.*Expected 2.*')

      util.assert_that(result, check_result)

      runner_result = pipeline.run()
      # we expect one missing samples counter increment for slice2, since we
      # expected 2 samples, but only saw 1.
      metric_filter = beam.metrics.metric.MetricsFilter().with_name(
          'num_slices_missing_samples')
      counters = runner_result.metrics().query(filter=metric_filter)['counters']
      self.assertLen(counters, 1)
      self.assertEqual(1, counters[0].committed)

      # verify total slice counter
      metric_filter = beam.metrics.metric.MetricsFilter().with_name(
          'num_slices')
      counters = runner_result.metrics().query(filter=metric_filter)['counters']
      self.assertLen(counters, 1)
      self.assertEqual(2, counters[0].committed)