def test_jackknife_combine_per_key(self):

    def dict_value_sum(dict_elements):
      """Toy combiner which sums dict values."""
      totals = collections.defaultdict(int)
      for element in dict_elements:
        for key, value in element.items():
          totals[key] += value
      return totals

    # Two slices, two extracts each.
    slice1 = ((u'slice_feature1', 1),)
    slice2 = ((u'slice_feature1', 2),)
    sliced_extracts = [
        (slice1, {u'label': 0}),
        (slice1, {u'label': 2}),
        (slice2, {u'label': 2}),
        (slice2, {u'label': 4}),
    ]
    with beam.Pipeline() as pipeline:
      result = (
          pipeline
          | 'Create' >> beam.Create(sliced_extracts, reshuffle=False)
          | 'JackknifeCombinePerKey' >> jackknife.JackknifeCombinePerKey(
              beam.combiners.SingleInputTupleCombineFn(dict_value_sum),
              num_jackknife_samples=2,
              random_seed=0))

      def check_result(got_pcoll):
        count_key = jackknife._JACKKNIFE_EXAMPLE_COUNT_METRIC_KEY
        # Each slice yields one unsampled (-1) output carrying the example
        # count, plus one output per jackknife sample.
        expected_pcoll = [
            (slice1 + ((u'_sample_id', -1),), ({'label': 2}, {count_key: 2})),
            (slice1 + ((u'_sample_id', 0),), ({'label': 2},)),
            (slice1 + ((u'_sample_id', 1),), ({'label': 0},)),
            (slice2 + ((u'_sample_id', -1),), ({'label': 6}, {count_key: 2})),
            (slice2 + ((u'_sample_id', 0),), ({'label': 2},)),
            (slice2 + ((u'_sample_id', 1),), ({'label': 4},)),
        ]
        self.assertCountEqual(expected_pcoll, got_pcoll)

      util.assert_that(result, check_result)
def _ComputePerSlice(  # pylint: disable=invalid-name
        sliced_extracts: beam.pvalue.PCollection,
        computations: List[metric_types.MetricComputation],
        derived_computations: List[metric_types.DerivedMetricComputation],
        cross_slice_specs: Optional[Iterable[config.CrossSlicingSpec]] = None,
        compute_with_sampling: Optional[bool] = False,
        num_jackknife_samples: int = 0,
        skip_ci_metric_keys: Set[metric_types.MetricKey] = frozenset(),
        random_seed_for_testing: Optional[int] = None,
        baseline_model_name: Optional[Text] = None) -> beam.pvalue.PCollection:
    """PTransform for computing, aggregating and combining metrics and plots.

    Args:
      sliced_extracts: Incoming PCollection consisting of slice key and
        extracts.
      computations: List of MetricComputations.
      derived_computations: List of DerivedMetricComputations.
      cross_slice_specs: List of CrossSlicingSpec.
      compute_with_sampling: True to compute with bootstrap sampling. This
        allows _ComputePerSlice to be used to generate unsampled values from
        the whole data set, as well as bootstrap resamples, in which each
        element is treated as if it showed up p ~ poisson(1) times.
      num_jackknife_samples: number of delete-d jackknife estimates to use in
        computing standard errors on metrics.
      skip_ci_metric_keys: List of metric keys for which to skip confidence
        interval computation.
      random_seed_for_testing: Seed to use for unit testing.
      baseline_model_name: Name for baseline model.

    Returns:
      PCollection of (slice key, dict of metrics).
    """
    # TODO(b/123516222): Remove this workaround per discussions in CL/227944001
    sliced_extracts.element_type = beam.typehints.Any

    def convert_and_add_derived_values(
        sliced_results: Tuple[slicer.SliceKeyType,
                              Tuple[metric_types.MetricsDict, ...]],
        derived_computations: List[metric_types.DerivedMetricComputation],
    ) -> Tuple[slicer.SliceKeyType, metric_types.MetricsDict]:
        """Converts per slice tuple of dicts into single dict and adds derived."""
        result = {}
        for v in sliced_results[1]:
            result.update(v)
        # Derived computations see all previously merged/derived values.
        for c in derived_computations:
            result.update(c.result(result))
        # Remove private metrics (single leading underscore). Keys starting
        # with a double underscore are intentionally kept.
        result = {
            k: v
            for k, v in result.items()
            if not (k.name.startswith('_') and not k.name.startswith('__'))
        }
        return sliced_results[0], result

    def add_diff_metrics(
        sliced_metrics: Tuple[Union[slicer.SliceKeyType,
                                    slicer.CrossSliceKeyType],
                              Dict[metric_types.MetricKey, Any]],
        baseline_model_name: Optional[Text],
    ) -> Tuple[slicer.SliceKeyType, Dict[metric_types.MetricKey, Any]]:
        """Add diff metrics if there is a baseline model."""

        result = copy.copy(sliced_metrics[1])

        if baseline_model_name:
            diff_result = {}
            for k, v in result.items():
                if k.model_name == baseline_model_name:
                    continue
                # Hoisted: compute the baseline key once per candidate metric.
                baseline_key = k.make_baseline_key(baseline_model_name)
                # Plots (proto messages) will not be diffed.
                if baseline_key in result and not isinstance(
                        v, message.Message):
                    diff_result[k.make_diff_key()] = v - result[baseline_key]
            result.update(diff_result)

        return (sliced_metrics[0], result)

    combiner = _ComputationsCombineFn(
        computations=computations,
        compute_with_sampling=compute_with_sampling,
        random_seed_for_testing=random_seed_for_testing)
    if num_jackknife_samples:
        # We do not use the hotkey fanout hint used by the non-jacknife path
        # because the random jackknife partitioning naturally mitigates hot
        # keys.
        sliced_combiner_outputs = (
            sliced_extracts
            | 'JackknifeCombinePerSliceKey' >>
            jackknife.JackknifeCombinePerKey(combiner, num_jackknife_samples))
    else:
        sliced_combiner_outputs = (
            sliced_extracts
            | 'CombinePerSliceKey' >> beam.CombinePerKey(combiner).
            with_hot_key_fanout(_COMBINE_PER_SLICE_KEY_HOT_KEY_FANOUT))

    sliced_derived_values_and_diffs = (
        sliced_combiner_outputs
        | 'ConvertAndAddDerivedValues' >> beam.Map(
            convert_and_add_derived_values, derived_computations)
        | 'AddCrossSliceMetrics' >> _AddCrossSliceMetrics(cross_slice_specs)  # pylint: disable=no-value-for-parameter
        | 'AddDiffMetrics' >> beam.Map(add_diff_metrics, baseline_model_name))

    if num_jackknife_samples:
        # Compute standard errors / CIs from the jackknife samples.
        return (sliced_derived_values_and_diffs
                | 'MergeJackknifeSamples' >> jackknife.MergeJackknifeSamples(
                    num_jackknife_samples, skip_ci_metric_keys))
    else:
        return sliced_derived_values_and_diffs