def test_jackknife_combine_per_key(self):

  def sum_dict_values(dicts):
    """Toy combiner which sums dict values."""
    totals = {}
    for element in dicts:
      for key, value in element.items():
        totals[key] = totals.get(key, 0) + value
    return totals

  inputs = [
      (((u'slice_feature1', 1),), {u'label': 0}),
      (((u'slice_feature1', 1),), {u'label': 2}),
      (((u'slice_feature1', 2),), {u'label': 2}),
      (((u'slice_feature1', 2),), {u'label': 4}),
  ]
  with beam.Pipeline() as pipeline:
    result = (
        pipeline
        | 'Create' >> beam.Create(inputs, reshuffle=False)
        | 'JackknifeCombinePerKey' >> jackknife.JackknifeCombinePerKey(
            beam.combiners.SingleInputTupleCombineFn(sum_dict_values),
            num_jackknife_samples=2,
            random_seed=0))

    def check_result(got_pcoll):
      # The unsampled entries (sample id -1) carry the example count dict
      # alongside the combined metrics; each jackknife sample (ids 0 and 1)
      # carries only the combined metrics.
      expected_pcoll = [
          (((u'slice_feature1', 1), (u'_sample_id', -1)),
           ({'label': 2},
            {jackknife._JACKKNIFE_EXAMPLE_COUNT_METRIC_KEY: 2})),
          (((u'slice_feature1', 1), (u'_sample_id', 0)), ({'label': 2},)),
          (((u'slice_feature1', 1), (u'_sample_id', 1)), ({'label': 0},)),
          (((u'slice_feature1', 2), (u'_sample_id', -1)),
           ({'label': 6},
            {jackknife._JACKKNIFE_EXAMPLE_COUNT_METRIC_KEY: 2})),
          (((u'slice_feature1', 2), (u'_sample_id', 0)), ({'label': 2},)),
          (((u'slice_feature1', 2), (u'_sample_id', 1)), ({'label': 4},)),
      ]
      self.assertCountEqual(expected_pcoll, got_pcoll)

    util.assert_that(result, check_result)
def _ComputePerSlice(  # pylint: disable=invalid-name
    sliced_extracts: beam.pvalue.PCollection,
    computations: List[metric_types.MetricComputation],
    derived_computations: List[metric_types.DerivedMetricComputation],
    cross_slice_specs: Optional[Iterable[config.CrossSlicingSpec]] = None,
    compute_with_sampling: Optional[bool] = False,
    num_jackknife_samples: int = 0,
    skip_ci_metric_keys: Set[metric_types.MetricKey] = frozenset(),
    random_seed_for_testing: Optional[int] = None,
    baseline_model_name: Optional[Text] = None) -> beam.pvalue.PCollection:
  """PTransform for computing, aggregating and combining metrics and plots.

  Args:
    sliced_extracts: Incoming PCollection consisting of slice key and extracts.
    computations: List of MetricComputations.
    derived_computations: List of DerivedMetricComputations.
    cross_slice_specs: List of CrossSlicingSpec.
    compute_with_sampling: True to compute with bootstrap sampling. This allows
      _ComputePerSlice to be used to generate unsampled values from the whole
      data set, as well as bootstrap resamples, in which each element is
      treated as if it showed up p ~ poisson(1) times.
    num_jackknife_samples: number of delete-d jackknife estimates to use in
      computing standard errors on metrics.
    skip_ci_metric_keys: List of metric keys for which to skip confidence
      interval computation.
    random_seed_for_testing: Seed to use for unit testing.
    baseline_model_name: Name for baseline model.

  Returns:
    PCollection of (slice key, dict of metrics).
  """
  # TODO(b/123516222): Remove this workaround per discussions in CL/227944001
  sliced_extracts.element_type = beam.typehints.Any

  def convert_and_add_derived_values(
      sliced_results: Tuple[slicer.SliceKeyType,
                            Tuple[metric_types.MetricsDict, ...]],
      derived_computations: List[metric_types.DerivedMetricComputation],
  ) -> Tuple[slicer.SliceKeyType, metric_types.MetricsDict]:
    """Converts per slice tuple of dicts into single dict and adds derived."""
    result = {}
    # Merge the per-computation metric dicts; later dicts win on key clashes.
    for v in sliced_results[1]:
      result.update(v)
    # Derived computations see all previously merged (and derived) values.
    for c in derived_computations:
      result.update(c.result(result))
    # Remove private metrics: keys whose names start with a single underscore
    # are internal-only and dropped; double-underscore names are kept.
    keys = list(result.keys())
    for k in keys:
      if k.name.startswith('_') and not k.name.startswith('__'):
        result.pop(k)
    return sliced_results[0], result

  def add_diff_metrics(
      sliced_metrics: Tuple[Union[slicer.SliceKeyType,
                                  slicer.CrossSliceKeyType],
                            Dict[metric_types.MetricKey, Any]],
      baseline_model_name: Optional[Text],
  ) -> Tuple[slicer.SliceKeyType, Dict[metric_types.MetricKey, Any]]:
    """Add diff metrics if there is a baseline model."""
    # Shallow copy so the added diff keys do not mutate the input dict.
    result = copy.copy(sliced_metrics[1])

    if baseline_model_name:
      diff_result = {}
      for k, v in result.items():
        # Only diff non-baseline metrics that have a matching baseline key.
        if k.model_name != baseline_model_name and k.make_baseline_key(
            baseline_model_name) in result:
          # plots will not be diffed.
          if not isinstance(v, message.Message):
            diff_result[k.make_diff_key()] = v - result[k.make_baseline_key(
                baseline_model_name)]
      result.update(diff_result)

    return (sliced_metrics[0], result)

  combiner = _ComputationsCombineFn(
      computations=computations,
      compute_with_sampling=compute_with_sampling,
      random_seed_for_testing=random_seed_for_testing)
  if num_jackknife_samples:
    # We do not use the hotkey fanout hint used by the non-jackknife path
    # because the random jackknife partitioning naturally mitigates hot keys.
    sliced_combiner_outputs = (
        sliced_extracts
        | 'JackknifeCombinePerSliceKey' >> jackknife.JackknifeCombinePerKey(
            combiner, num_jackknife_samples))
  else:
    sliced_combiner_outputs = (
        sliced_extracts
        | 'CombinePerSliceKey' >> beam.CombinePerKey(
            combiner).with_hot_key_fanout(
                _COMBINE_PER_SLICE_KEY_HOT_KEY_FANOUT))

  sliced_derived_values_and_diffs = (
      sliced_combiner_outputs
      | 'ConvertAndAddDerivedValues' >> beam.Map(
          convert_and_add_derived_values, derived_computations)
      | 'AddCrossSliceMetrics' >> _AddCrossSliceMetrics(cross_slice_specs)  # pylint: disable=no-value-for-parameter
      | 'AddDiffMetrics' >> beam.Map(add_diff_metrics, baseline_model_name))

  if num_jackknife_samples:
    # Collapse the per-sample metrics into point estimates with confidence
    # intervals (except for keys listed in skip_ci_metric_keys).
    return (sliced_derived_values_and_diffs
            | 'MergeJackknifeSamples' >> jackknife.MergeJackknifeSamples(
                num_jackknife_samples, skip_ci_metric_keys))
  else:
    return sliced_derived_values_and_diffs