def assertCombinerOutputEqual(
    self,
    batches: List[types.ExampleBatch],
    generator: stats_generator.CombinerStatsGenerator,
    expected_feature_stats: Dict[types.FeaturePath,
                                 statistics_pb2.FeatureNameStatistics],
    expected_cross_feature_stats: Optional[Dict[
        types.FeatureCross, statistics_pb2.CrossFeatureStatistics]] = None,
    only_match_expected_feature_stats: bool = False,
) -> None:
  """Tests a combiner statistics generator.

  This runs the generator twice to cover different behavior. There must be
  at least two input batches in order to test the generator's merging
  behavior.

  Args:
    batches: A list of batches of test data.
    generator: The CombinerStatsGenerator to test.
    expected_feature_stats: Dict mapping feature name to FeatureNameStatistics
      proto that it is expected the generator will return for the feature.
    expected_cross_feature_stats: Dict mapping feature cross to
      CrossFeatureStatistics proto that it is expected the generator will
      return for the feature cross.
    only_match_expected_feature_stats: if True, will only compare features
      that appear in `expected_feature_stats`.
  """
  if expected_cross_feature_stats is None:
    expected_cross_feature_stats = {}

  def _verify(output):
    """Verifies that `output` meets the expectations."""
    if only_match_expected_feature_stats:
      features_in_stats = {
          types.FeaturePath.from_proto(f.path) for f in output.features
      }
      # Every expected feature must be present; extra features are allowed.
      missing_features = set(expected_feature_stats) - features_in_stats
      self.assertFalse(
          missing_features,
          'Expected features missing from output: {}; output: {}'.format(
              missing_features, output))
    else:
      self.assertEqual(  # pylint: disable=g-generic-assert
          len(output.features), len(expected_feature_stats),
          '{}, {}'.format(output, expected_feature_stats))

    for actual_feature_stats in output.features:
      actual_path = types.FeaturePath.from_proto(actual_feature_stats.path)
      expected_stats = expected_feature_stats.get(actual_path)
      if expected_stats is None:
        if only_match_expected_feature_stats:
          continue
        # Counts matched but the key sets differ; fail with a clear message
        # instead of comparing the actual proto against None.
        self.fail('Unexpected feature {} in output: {}; expected: {}'.format(
            actual_path, output, expected_feature_stats))
      compare.assertProtoEqual(
          self, actual_feature_stats, expected_stats, normalize_numbers=True)

    # Check the argument that was passed in, not a variable captured from the
    # enclosing scope.
    self.assertEqual(  # pylint: disable=g-generic-assert
        len(output.cross_features), len(expected_cross_feature_stats),
        '{}, {}'.format(output, expected_cross_feature_stats))
    for actual_cross_feature_stats in output.cross_features:
      # Crosses are keyed by the first step of each path.
      cross = (actual_cross_feature_stats.path_x.step[0],
               actual_cross_feature_stats.path_y.step[0])
      self.assertIn(
          cross, expected_cross_feature_stats,
          'Unexpected feature cross {} in output: {}'.format(cross, output))
      compare.assertProtoEqual(
          self,
          actual_cross_feature_stats,
          expected_cross_feature_stats[cross],
          normalize_numbers=True)

  # Run generator to check that merge_accumulators() works correctly.
  accumulators = [
      generator.add_input(generator.create_accumulator(), batch)
      for batch in batches
  ]
  result = generator.extract_output(
      generator.merge_accumulators(accumulators))
  _verify(result)

  # Run generator to check that add_input() works correctly when adding
  # inputs to a non-empty accumulator.
  accumulator = generator.create_accumulator()
  for batch in batches:
    accumulator = generator.add_input(accumulator, batch)
  result = generator.extract_output(accumulator)
  _verify(result)
def assertCombinerOutputEqual(
    self,
    batches: List[types.ExampleBatch],
    generator: stats_generator.CombinerStatsGenerator,
    expected_feature_stats: Dict[types.FeaturePath,
                                 statistics_pb2.FeatureNameStatistics],
    expected_cross_feature_stats: Optional[Dict[
        types.FeatureCross, statistics_pb2.CrossFeatureStatistics]] = None
) -> None:
  """Tests a combiner statistics generator.

  This runs the generator twice to cover different behavior. There must be
  at least two input batches in order to test the generator's merging
  behavior.

  Args:
    batches: A list of batches of test data.
    generator: The CombinerStatsGenerator to test.
    expected_feature_stats: Dict mapping feature name to FeatureNameStatistics
      proto that it is expected the generator will return for the feature.
    expected_cross_feature_stats: Dict mapping feature cross to
      CrossFeatureStatistics proto that it is expected the generator will
      return for the feature cross.
  """
  if expected_cross_feature_stats is None:
    expected_cross_feature_stats = {}

  def _verify(output):
    """Verifies that `output` exactly matches the expected statistics."""
    self.assertEqual(  # pylint: disable=g-generic-assert
        len(output.features), len(expected_feature_stats),
        '{}, {}'.format(output, expected_feature_stats))
    for actual_feature_stats in output.features:
      actual_path = types.FeaturePath.from_proto(actual_feature_stats.path)
      # Counts can match while key sets differ; report that clearly rather
      # than raising a bare KeyError.
      self.assertIn(
          actual_path, expected_feature_stats,
          'Unexpected feature {} in output: {}'.format(actual_path, output))
      compare.assertProtoEqual(
          self,
          actual_feature_stats,
          expected_feature_stats[actual_path],
          normalize_numbers=True)

    self.assertEqual(  # pylint: disable=g-generic-assert
        len(output.cross_features), len(expected_cross_feature_stats),
        '{}, {}'.format(output, expected_cross_feature_stats))
    for actual_cross_feature_stats in output.cross_features:
      # Crosses are keyed by the first step of each path.
      cross = (actual_cross_feature_stats.path_x.step[0],
               actual_cross_feature_stats.path_y.step[0])
      self.assertIn(
          cross, expected_cross_feature_stats,
          'Unexpected feature cross {} in output: {}'.format(cross, output))
      compare.assertProtoEqual(
          self,
          actual_cross_feature_stats,
          expected_cross_feature_stats[cross],
          normalize_numbers=True)

  # Run generator to check that merge_accumulators() works correctly.
  accumulators = [
      generator.add_input(generator.create_accumulator(), batch)
      for batch in batches
  ]
  result = generator.extract_output(
      generator.merge_accumulators(accumulators))
  _verify(result)

  # Run generator to check that add_input() works correctly when adding
  # inputs to a non-empty accumulator.
  accumulator = generator.create_accumulator()
  for batch in batches:
    accumulator = generator.add_input(accumulator, batch)
  result = generator.extract_output(accumulator)
  _verify(result)