Example #1
0
    def _assert_test(self,
                     num_buckets,
                     baseline_examples,
                     comparison_examples,
                     lift_metric_value,
                     ignore_out_of_bound_examples=False):
        """Runs Lift over a baseline and a comparison slice and checks the value.

        Both example sets are converted with
        `metric_util.to_standard_metric_inputs`, keyed by a slice, and reduced
        with the Lift histogram combiner. The two per-slice histograms are then
        passed to the Lift cross-slice comparison, and the result is checked
        against `lift_metric_value`.

        Args:
          num_buckets: Number of buckets passed to `lift.Lift`; also determines
            the expected metric name `lift@<num_buckets>`.
          baseline_examples: Examples for the baseline slice, keyed by the empty
            slice key `()`.
          comparison_examples: Examples for the comparison slice, keyed by the
            non-empty slice key `('slice',)`.
          lift_metric_value: Expected value of the `lift@<num_buckets>` metric.
          ignore_out_of_bound_examples: Forwarded unchanged to `lift.Lift`.

        The Beam pipeline fails (via `util.assert_that`) if the result does not
        contain exactly two slices or if the computed lift differs from
        `lift_metric_value`.
        """
        eval_config = config.EvalConfig(
            cross_slicing_specs=[config.CrossSlicingSpec()])
        computations = lift.Lift(
            num_buckets=num_buckets,
            ignore_out_of_bound_examples=ignore_out_of_bound_examples
        ).computations(eval_config=eval_config)
        # computations[0] provides the per-slice histogram combiner;
        # computations[1] compares two such histograms across slices.
        histogram = computations[0]
        lift_metrics = computations[1]

        with beam.Pipeline() as pipeline:
            # pylint: disable=no-value-for-parameter
            baseline_result = (
                pipeline
                | 'CreateB' >> beam.Create(baseline_examples)
                | 'ProcessB' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSliceB' >> beam.Map(lambda x: ((), x))
                | 'ComputeHistogramB' >> beam.CombinePerKey(histogram.combiner)
            )  # pyformat: ignore

            # The comparison key is a one-element tuple (note the trailing
            # comma) so it has the same type as the baseline key `()`. It must
            # be truthy so check_result can tell the two slices apart.
            comparison_result = (
                pipeline
                | 'CreateC' >> beam.Create(comparison_examples)
                | 'ProcessC' >> beam.Map(metric_util.to_standard_metric_inputs)
                | 'AddSliceC' >> beam.Map(lambda x: (('slice',), x))
                | 'ComputeHistogramC' >> beam.CombinePerKey(histogram.combiner)
            )  # pyformat: ignore

            # pylint: enable=no-value-for-parameter

            merged_result = ((baseline_result, comparison_result)
                             | 'MergePCollections' >> beam.Flatten())

            def check_result(got):
                try:
                    self.assertLen(got, 2)
                    (slice_a, metric_a), (_, metric_b) = got
                    # The order of `got` is not relied upon: the baseline is
                    # the entry whose slice key is empty (falsy), and it is
                    # always passed to the comparison first.
                    if slice_a:
                        metric_a, metric_b = metric_b, metric_a
                    lift_value = lift_metrics.cross_slice_comparison(
                        metric_a, metric_b)

                    self.assertDictElementsAlmostEqual(
                        lift_value, {
                            metric_types.MetricKey(name=f'lift@{num_buckets}'):
                            lift_metric_value,
                        })
                except AssertionError as err:
                    raise util.BeamAssertException(err) from err

            util.assert_that(merged_result, check_result, label='result')
 def testLift_raisesExceptionWhenCrossSlicingSpecIsAbsent(self):
   """Lift.computations raises ValueError for an EvalConfig without
   cross_slicing_specs."""
   with self.assertRaises(ValueError):
     lift_metric = lift.Lift(num_buckets=3)
     # EvalConfig is built with no arguments, so no cross-slicing spec is set.
     lift_metric.computations(eval_config=config_pb2.EvalConfig())
 def testLift_raisesExceptionWhenEvalConfigIsNone(self):
   """Lift.computations raises ValueError when called without an eval_config."""
   with self.assertRaises(ValueError):
     lift_metric = lift.Lift(num_buckets=3)
     lift_metric.computations()