# Example #1
0
    def assertCombinerOutputEqual(
            self,
            input_batches: List[types.ValueBatch],
            generator: stats_generator.CombinerFeatureStatsGenerator,
            expected_result: statistics_pb2.FeatureNameStatistics,
            feature_path: types.FeaturePath = types.FeaturePath([''])
    ) -> None:
        """Tests a feature combiner statistics generator.

        The generator is run twice over the same batches to cover different
        code paths:

          1. Each batch is added to its own freshly created accumulator, and the
             accumulators are combined with ``merge_accumulators()``.
          2. All batches are added one after another to a single accumulator
             with ``add_input()``.

        Both outputs must equal ``expected_result``. There should be at least
        two input batches so that the merging behavior is actually exercised.

        Args:
          input_batches: A list of batches of test data. Any iterable is
            accepted; it is materialized once so that every pass sees every
            batch.
          generator: The CombinerFeatureStatsGenerator to test.
          expected_result: The FeatureNameStatistics proto that the generator
            is expected to return.
          feature_path: The FeaturePath passed to ``add_input()``. Defaults to
            ``types.FeaturePath([''])``.
        """
        # Both passes below iterate over the batches; a one-shot iterator would
        # otherwise leave the second pass with no input at all.
        input_batches = list(input_batches)

        # Pass 1: check that merge_accumulators() works correctly.
        accumulators = [
            generator.add_input(generator.create_accumulator(), feature_path,
                                input_batch) for input_batch in input_batches
        ]
        # Assume that generators will never be called with empty inputs, so
        # merge a single empty accumulator instead of an empty list.
        accumulators = accumulators or [generator.create_accumulator()]
        result = generator.extract_output(
            generator.merge_accumulators(accumulators))
        # Conventional (expected, actual) argument order.
        compare.assertProtoEqual(self,
                                 expected_result,
                                 result,
                                 normalize_numbers=True)

        # Pass 2: check that add_input() works correctly when adding inputs to
        # a non-empty accumulator.
        accumulator = generator.create_accumulator()
        for input_batch in input_batches:
            accumulator = generator.add_input(accumulator, feature_path,
                                              input_batch)

        result = generator.extract_output(accumulator)
        compare.assertProtoEqual(self,
                                 expected_result,
                                 result,
                                 normalize_numbers=True)
# Example #2
0
    def assertCombinerOutputEqual(
        self,
        input_batches: List[pa.RecordBatch],
        generator: stats_generator.CombinerFeatureStatsGenerator,
        expected_result: statistics_pb2.FeatureNameStatistics,
        feature_path: types.FeaturePath = types.FeaturePath([''])
    ) -> None:
        """Tests a feature combiner statistics generator.

        After calling ``generator.setup()``, the generator is run three times
        over the same batches to cover different code paths:

          1. One freshly created accumulator per batch, combined with
             ``merge_accumulators()``.
          2. The same as (1), followed by ``compact()`` on the merged
             accumulator.
          3. All batches added one after another to a single accumulator with
             ``add_input()``.

        Every output must equal ``expected_result``. There should be at least
        two input batches so that the merging behavior is actually exercised.

        Args:
          input_batches: A list of batches of test data. Any iterable is
            accepted; it is materialized once so that every pass sees every
            batch.
          generator: The CombinerFeatureStatsGenerator to test.
          expected_result: The FeatureNameStatistics proto that the generator
            is expected to return.
          feature_path: The FeaturePath passed to ``add_input()``. Defaults to
            ``types.FeaturePath([''])``.
        """
        generator.setup()

        # All three passes iterate over the batches; a one-shot iterator would
        # otherwise leave the later passes with no input at all.
        input_batches = list(input_batches)

        def _per_batch_accumulators():
            # Builds one freshly created accumulator per batch. Called anew for
            # each merge pass instead of reusing the previous list, presumably
            # because merge_accumulators() may modify its inputs (TODO confirm).
            accumulators = [
                generator.add_input(generator.create_accumulator(),
                                    feature_path, input_batch)
                for input_batch in input_batches
            ]
            # Assume that generators will never be called with empty inputs.
            return accumulators or [generator.create_accumulator()]

        # Pass 1: check that merge_accumulators() works correctly.
        result = generator.extract_output(
            generator.merge_accumulators(_per_batch_accumulators()))
        compare.assertProtoEqual(self,
                                 expected_result,
                                 result,
                                 normalize_numbers=True)

        # Pass 2: check that compact() works correctly after merging
        # accumulators.
        result = generator.extract_output(
            generator.compact(
                generator.merge_accumulators(_per_batch_accumulators())))
        compare.assertProtoEqual(self,
                                 expected_result,
                                 result,
                                 normalize_numbers=True)

        # Pass 3: check that add_input() works correctly when adding inputs to
        # a non-empty accumulator.
        accumulator = generator.create_accumulator()
        for input_batch in input_batches:
            accumulator = generator.add_input(accumulator, feature_path,
                                              input_batch)

        result = generator.extract_output(accumulator)
        compare.assertProtoEqual(self,
                                 expected_result,
                                 result,
                                 normalize_numbers=True)