def assertCombinerOutputEqual(
    self,
    input_batches: List[types.ValueBatch],
    generator: stats_generator.CombinerFeatureStatsGenerator,
    expected_result: statistics_pb2.FeatureNameStatistics,
    feature_path: types.FeaturePath = types.FeaturePath([''])) -> None:
  """Tests a feature combiner statistics generator.

  This runs the generator twice to cover different behavior:
    1. one accumulator per input batch, combined with merge_accumulators();
    2. all input batches added sequentially to a single accumulator.
  Both passes must produce `expected_result`. There must be at least two
  input batches in order to test the generator's merging behavior.

  Args:
    input_batches: A list of batches of test data.
    generator: The CombinerFeatureStatsGenerator to test.
    expected_result: The FeatureNameStatistics proto that it is expected the
      generator will return.
    feature_path: The FeaturePath passed to add_input(). Defaults to
      types.FeaturePath(['']), the path this helper always used previously.
  """
  # Run generator to check that merge_accumulators() works correctly.
  accumulators = [
      generator.add_input(generator.create_accumulator(), feature_path,
                          input_batch) for input_batch in input_batches
  ]
  # Assume that generators will never be called with empty inputs, so never
  # hand merge_accumulators() an empty list.
  accumulators = accumulators or [generator.create_accumulator()]
  result = generator.extract_output(
      generator.merge_accumulators(accumulators))
  # Expected proto first, matching the argument order used for
  # assertProtoEqual elsewhere in this file.
  compare.assertProtoEqual(
      self, expected_result, result, normalize_numbers=True)

  # Run generator to check that add_input() works correctly when adding
  # inputs to a non-empty accumulator.
  accumulator = generator.create_accumulator()
  for input_batch in input_batches:
    accumulator = generator.add_input(accumulator, feature_path, input_batch)
  result = generator.extract_output(accumulator)
  compare.assertProtoEqual(
      self, expected_result, result, normalize_numbers=True)
def assertCombinerOutputEqual(
    self,
    input_batches: List[pa.RecordBatch],
    generator: stats_generator.CombinerFeatureStatsGenerator,
    expected_result: statistics_pb2.FeatureNameStatistics,
    feature_path: types.FeaturePath = types.FeaturePath([''])
) -> None:
  """Tests a feature combiner statistics generator.

  This runs the generator three times to cover different behavior:
    1. one accumulator per input batch, combined with merge_accumulators();
    2. the same, followed by compact() on the merged accumulator;
    3. all input batches added sequentially to a single accumulator.
  Every pass must produce `expected_result`. There must be at least two input
  batches in order to test the generator's merging behavior.

  Args:
    input_batches: A list of batches of test data.
    generator: The CombinerFeatureStatsGenerator to test.
    expected_result: The FeatureNameStatistics proto that it is expected the
      generator will return.
    feature_path: The FeaturePath to use, if not specified, will set a default
      value.
  """
  generator.setup()

  def _merged_per_batch_accumulator():
    """Adds each batch to its own fresh accumulator and merges the results."""
    accumulators = [
        generator.add_input(generator.create_accumulator(), feature_path,
                            input_batch) for input_batch in input_batches
    ]
    # Assume that generators will never be called with empty inputs.
    accumulators = accumulators or [generator.create_accumulator()]
    return generator.merge_accumulators(accumulators)

  # Run generator to check that merge_accumulators() works correctly.
  result = generator.extract_output(_merged_per_batch_accumulator())
  compare.assertProtoEqual(
      self, expected_result, result, normalize_numbers=True)

  # Run generator to check that compact() works correctly after merging
  # accumulators. Fresh accumulators are built rather than reusing the first
  # pass's, so this check does not depend on whether merging mutated them.
  result = generator.extract_output(
      generator.compact(_merged_per_batch_accumulator()))
  compare.assertProtoEqual(
      self, expected_result, result, normalize_numbers=True)

  # Run generator to check that add_input() works correctly when adding
  # inputs to a non-empty accumulator.
  accumulator = generator.create_accumulator()
  for input_batch in input_batches:
    accumulator = generator.add_input(accumulator, feature_path, input_batch)
  result = generator.extract_output(accumulator)
  compare.assertProtoEqual(
      self, expected_result, result, normalize_numbers=True)