def test_basic_stats_generator_categorical_feature(self):
     """Checks stats for an INT feature that the schema marks as categorical.

     The expected proto carries `string_stats` (including `avg_length`)
     instead of `num_stats` for feature 'c'.
     """
     batches = [{
         # The two rows have different lengths. NumPy >= 1.24 refuses to build
         # such a ragged array implicitly, so request an object array
         # explicitly; older releases produced the same object array anyway.
         'c': np.array([np.array([1, 5, 10]),
                        np.array([0])], dtype=object)
     }, {
         'c': np.array([np.array([1, 1, 1, 5, 15])])
     }]
     expected_result = {
         'c':
         text_format.Parse(
             """
         name: 'c'
         type: INT
         string_stats {
           common_stats {
             num_non_missing: 3
             num_missing: 0
             min_num_values: 1
             max_num_values: 5
             avg_num_values: 3.0
             tot_num_values: 9
             num_values_histogram {
               buckets {
                 low_value: 1.0
                 high_value: 3.0
                 sample_count: 1.0
               }
               buckets {
                 low_value: 3.0
                 high_value: 5.0
                 sample_count: 1.0
               }
               buckets {
                 low_value: 5.0
                 high_value: 5.0
                 sample_count: 1.0
               }
               type: QUANTILES
             }
           }
           avg_length: 1.22222222
         }
         """, statistics_pb2.FeatureNameStatistics())
     }
     # Schema declaring feature 'c' as a categorical int.
     schema = text_format.Parse(
         """
     feature {
       name: "c"
       type: INT
       int_domain {
         is_categorical: true
       }
     }
     """, schema_pb2.Schema())
     generator = basic_stats_generator.BasicStatsGenerator(
         schema=schema,
         num_values_histogram_buckets=3,
         num_histogram_buckets=3,
         num_quantiles_histogram_buckets=4)
     self.assertCombinerOutputEqual(batches, generator, expected_result)
 def test_basic_stats_generator_invalid_value_numpy_dtype(self):
     """Expects a TypeError when a feature value array holds complex numbers."""
     batches = [{'a': [np.array([1 + 2j])]}]
     generator = basic_stats_generator.BasicStatsGenerator()
     # assertRaisesRegexp was a deprecated alias and is gone in Python 3.12;
     # assertRaisesRegex is the supported spelling with identical semantics.
     with self.assertRaisesRegex(
             TypeError,
             'Feature a has value.*, should be int, float or str types.'):
         self.assertCombinerOutputEqual(batches, generator, None)
 def test_basic_stats_generator_invalid_value_type(self):
     """Expects a TypeError when a feature's value list contains a dict."""
     batches = [{'a': [{}]}]
     generator = basic_stats_generator.BasicStatsGenerator()
     # assertRaisesRegexp was a deprecated alias and is gone in Python 3.12;
     # assertRaisesRegex is the supported spelling with identical semantics.
     with self.assertRaisesRegex(
             TypeError,
             'Feature a has value.*, should be numpy.ndarray or None'):
         self.assertCombinerOutputEqual(batches, generator, None)
예제 #4
0
def _get_default_generators(
        options: stats_options.StatsOptions,
        in_memory: bool = False) -> List[stats_generator.StatsGenerator]:
    """Builds the default list of stats generators for `options`.

    The list always starts with a BasicStatsGenerator and is followed by
    exactly one top-k/uniques generator:
      * the sketch-based one when
        `options.experimental_use_sketch_based_topk_uniques` is truthy,
      * otherwise the combiner-based one when `in_memory` is truthy,
      * otherwise TopKUniquesStatsGenerator.

    Args:
      options: A StatsOptions object.
      in_memory: Whether the generators will be used to generate statistics
        in memory (True) or using Beam (False).

    Returns:
      A list of stats generator objects.
    """
    basic_generator = basic_stats_generator.BasicStatsGenerator(
        schema=options.schema,
        example_weight_map=options.example_weight_map,
        num_values_histogram_buckets=options.num_values_histogram_buckets,
        num_histogram_buckets=options.num_histogram_buckets,
        num_quantiles_histogram_buckets=(
            options.num_quantiles_histogram_buckets),
        epsilon=options.epsilon)

    # Keyword arguments that every top-k/uniques variant receives.
    top_k_kwargs = dict(
        schema=options.schema,
        example_weight_map=options.example_weight_map,
        num_top_values=options.num_top_values,
        num_rank_histogram_buckets=options.num_rank_histogram_buckets,
        frequency_threshold=options.frequency_threshold,
        weighted_frequency_threshold=options.weighted_frequency_threshold)

    if options.experimental_use_sketch_based_topk_uniques:
        # Only the sketch-based variant takes the sketch size arguments.
        sketch_module = top_k_uniques_sketch_stats_generator
        top_k_generator = sketch_module.TopKUniquesSketchStatsGenerator(
            num_misragries_buckets=_DEFAULT_MG_SKETCH_SIZE,
            num_kmv_buckets=_DEFAULT_KMV_SKETCH_SIZE,
            **top_k_kwargs)
    elif in_memory:
        combiner_module = top_k_uniques_combiner_stats_generator
        top_k_generator = combiner_module.TopKUniquesCombinerStatsGenerator(
            **top_k_kwargs)
    else:
        top_k_generator = (
            top_k_uniques_stats_generator.TopKUniquesStatsGenerator(
                **top_k_kwargs))

    return [basic_generator, top_k_generator]
 def test_basic_stats_generator_categorical_feature(self):
     """Pins the stats emitted for a categorical INT feature fed via Arrow.

     Feature 'c' is declared `is_categorical` in the schema, and the expected
     proto holds `string_stats` with an `avg_length` for it.
     """
     categorical_schema = text_format.Parse(
         """
     feature {
       name: "c"
       type: INT
       int_domain {
         is_categorical: true
       }
     }
     """, schema_pb2.Schema())
     stats_gen = basic_stats_generator.BasicStatsGenerator(
         schema=categorical_schema,
         num_values_histogram_buckets=3,
         num_histogram_buckets=3,
         num_quantiles_histogram_buckets=4)

     # Two tables with two examples of feature 'c' each.
     first_table = pa.Table.from_arrays(
         [pa.array([[1, 5, 10], [0]])], ['c'])
     second_table = pa.Table.from_arrays(
         [pa.array([[1, 1, 1, 5, 15], [-1]])], ['c'])

     expected_c_stats = text_format.Parse(
             """
         path {
           step: 'c'
         }
         string_stats {
           common_stats {
             num_non_missing: 4
             min_num_values: 1
             max_num_values: 5
             avg_num_values: 2.5
             num_values_histogram {
               buckets {
                 low_value: 1.0
                 high_value: 1.0
                 sample_count: 1.3333333
               }
               buckets {
                 low_value: 1.0
                 high_value: 3.0
                 sample_count: 1.3333333
               }
               buckets {
                 low_value: 3.0
                 high_value: 5.0
                 sample_count: 1.3333333
               }
               type: QUANTILES
             }
             tot_num_values: 10
           }
           avg_length: 1.29999995232
         }
         """, statistics_pb2.FeatureNameStatistics())

     self.assertCombinerOutputEqual(
         [first_table, second_table], stats_gen,
         {types.FeaturePath(['c']): expected_c_stats})
 def test_basic_stats_generator_feature_with_different_types(self):
     """Expects a TypeError when feature 'a' is float in one table, int in another."""
     batches = [
         pa.Table.from_arrays([pa.array([[1.0, 2.0], [3.0, 4.0, 5.0]])],
                              ['a']),
         pa.Table.from_arrays([pa.array([[1]])], ['a']),
     ]
     generator = basic_stats_generator.BasicStatsGenerator()
     # assertRaisesRegexp was a deprecated alias and is gone in Python 3.12;
     # assertRaisesRegex is the supported spelling with identical semantics.
     with self.assertRaisesRegex(TypeError, 'Cannot determine the type'):
         self.assertCombinerOutputEqual(batches, generator, None)
 def test_basic_stats_generator_invalid_value_numpy_dtype(self):
     """Expects a TypeError for a column typed as list<date32>."""
     batches = [
         pa.Table.from_arrays([pa.array([[]], type=pa.list_(pa.date32()))],
                              ['a'])
     ]
     generator = basic_stats_generator.BasicStatsGenerator()
     # assertRaisesRegexp was a deprecated alias and is gone in Python 3.12;
     # assertRaisesRegex is the supported spelling with identical semantics.
     with self.assertRaisesRegex(TypeError,
                                 'Feature a has unsupported arrow type'):
         self.assertCombinerOutputEqual(batches, generator, None)
 def test_basic_stats_generator_feature_with_different_types(self):
     """Expects a TypeError when feature 'a' is float in one batch, int in another."""
     batches = [{
         'a': [np.array([1.0, 2.0]),
               np.array([3.0, 4.0, 5.0])]
     }, {
         'a': [np.array([1])]
     }]
     generator = basic_stats_generator.BasicStatsGenerator()
     # assertRaisesRegexp was a deprecated alias and is gone in Python 3.12;
     # assertRaisesRegex is the supported spelling with identical semantics.
     with self.assertRaisesRegex(TypeError, 'Cannot determine the type'):
         self.assertCombinerOutputEqual(batches, generator, None)
 def test_basic_stats_generator_no_value_in_batch(self):
     """Checks stats when every example of feature 'a' is an empty list.

     The single input table has three examples, all empty int64 lists; the
     generator is used with its default arguments.
     """
     all_empty_lists = pa.array([[], [], []], type=pa.list_(pa.int64()))
     input_tables = [pa.Table.from_arrays([all_empty_lists], ['a'])]

     expected_a_stats = text_format.Parse(
             """
         path {
           step: 'a'
         }
         num_stats {
           common_stats {
             num_non_missing: 3
             num_values_histogram {
               buckets {
                 sample_count: 0.3
               }
               buckets {
                 sample_count: 0.3
               }
               buckets {
                 sample_count: 0.3
               }
               buckets {
                 sample_count: 0.3
               }
               buckets {
                 sample_count: 0.3
               }
               buckets {
                 sample_count: 0.3
               }
               buckets {
                 sample_count: 0.3
               }
               buckets {
                 sample_count: 0.3
               }
               buckets {
                 sample_count: 0.3
               }
               buckets {
                 sample_count: 0.3
               }
               type: QUANTILES
             }
           }
         }""", statistics_pb2.FeatureNameStatistics())

     self.assertCombinerOutputEqual(
         input_tables,
         basic_stats_generator.BasicStatsGenerator(),
         {types.FeaturePath(['a']): expected_a_stats})
 def test_basic_stats_generator_only_nan(self):
     """Checks stats for a float feature whose only value is NaN.

     Both expected histograms (STANDARD and QUANTILES) carry `num_nan: 1`
     and no buckets.
     """
     # np.NaN was removed in NumPy 2.0; np.nan is the same value and works on
     # every NumPy release.
     b1 = pa.Table.from_arrays(
         [pa.array([[np.nan]], type=pa.list_(pa.float32()))], ['a'])
     batches = [b1]
     expected_result = {
         types.FeaturePath(['a']):
         text_format.Parse(
             """
         path {
           step: 'a'
         }
         type: FLOAT
         num_stats {
           common_stats {
             num_non_missing: 1
             min_num_values: 1
             max_num_values: 1
             avg_num_values: 1.0
             tot_num_values: 1
             num_values_histogram {
               buckets {
                 low_value: 1.0
                 high_value: 1.0
                 sample_count: 0.5
               }
               buckets {
                 low_value: 1.0
                 high_value: 1.0
                 sample_count: 0.5
               }
               type: QUANTILES
             }
           }
           histograms {
             num_nan: 1
             type: STANDARD
           }
           histograms {
             num_nan: 1
             type: QUANTILES
           }
         }
         """, statistics_pb2.FeatureNameStatistics())
     }
     generator = basic_stats_generator.BasicStatsGenerator(
         num_values_histogram_buckets=2,
         num_histogram_buckets=3,
         num_quantiles_histogram_buckets=4)
     self.assertCombinerOutputEqual(batches, generator, expected_result)
 def test_basic_stats_generator_no_runtime_warnings_close_to_max_int(self):
     """Adds and merges near-max int64 inputs and requires no warnings.

     The same one-example table (value: int64 max minus one) is fed twice;
     each copy gets its own accumulator and the two are then merged, all
     inside `np.testing.assert_no_warnings()`.
     """
     near_max_int64 = np.iinfo(np.int64).max - 1
     table = pa.Table.from_arrays([pa.array([[near_max_int64]])], ['a'])
     stats_gen = basic_stats_generator.BasicStatsGenerator()

     with np.testing.assert_no_warnings():
         per_batch_accumulators = []
         # The identical table object is used for both batches.
         for batch in (table, table):
             fresh_accumulator = stats_gen.create_accumulator()
             per_batch_accumulators.append(
                 stats_gen.add_input(fresh_accumulator, batch))
         stats_gen.merge_accumulators(per_batch_accumulators)
예제 #12
0
def _get_default_generators(
        options: stats_options.StatsOptions,
        in_memory: bool = False) -> List[stats_generator.StatsGenerator]:
    """Builds the default list of stats generators for `options`.

    The list holds, in order, a BasicStatsGenerator, a
    NumExamplesStatsGenerator, and one top-k/uniques generator: the
    combiner-based one when `in_memory` is truthy, otherwise
    TopKUniquesStatsGenerator.

    Args:
      options: A StatsOptions object.
      in_memory: Whether the generators will be used to generate statistics
        in memory (True) or using Beam (False).

    Returns:
      A list of stats generator objects.
    """
    basic_generator = basic_stats_generator.BasicStatsGenerator(
        schema=options.schema,
        weight_feature=options.weight_feature,
        num_values_histogram_buckets=options.num_values_histogram_buckets,
        num_histogram_buckets=options.num_histogram_buckets,
        num_quantiles_histogram_buckets=(
            options.num_quantiles_histogram_buckets),
        epsilon=options.epsilon)
    num_examples_generator = NumExamplesStatsGenerator(options.weight_feature)

    # Both top-k/uniques variants are constructed with the same arguments;
    # only the class differs.
    if in_memory:
        top_k_class = (top_k_uniques_combiner_stats_generator.
                       TopKUniquesCombinerStatsGenerator)
    else:
        top_k_class = top_k_uniques_stats_generator.TopKUniquesStatsGenerator
    top_k_generator = top_k_class(
        schema=options.schema,
        weight_feature=options.weight_feature,
        num_top_values=options.num_top_values,
        frequency_threshold=options.frequency_threshold,
        weighted_frequency_threshold=options.weighted_frequency_threshold,
        num_rank_histogram_buckets=options.num_rank_histogram_buckets)

    return [basic_generator, num_examples_generator, top_k_generator]
 def test_basic_stats_generator_empty_batch(self):
   """Checks the stats produced when feature 'a' is an empty array."""
   empty_batch = {'a': np.array([])}
   expected_a_stats = text_format.Parse(
           """
           name: 'a'
           type: STRING
           string_stats {
             common_stats {
               num_non_missing: 0
               num_missing: 0
               tot_num_values: 0
             }
           }
           """, statistics_pb2.FeatureNameStatistics())
   self.assertCombinerOutputEqual(
       [empty_batch],
       basic_stats_generator.BasicStatsGenerator(),
       {'a': expected_a_stats})
예제 #14
0
 def test_basic_stats_generator_empty_batch(self):
     """Checks the stats produced for a zero-row list<binary> column 'a'."""
     empty_column = pa.array([], type=pa.list_(pa.binary()))
     input_tables = [pa.Table.from_arrays([empty_column], ['a'])]
     expected_a_stats = text_format.Parse(
             """
         name: 'a'
         type: STRING
         string_stats {
           common_stats {
             num_non_missing: 0
             tot_num_values: 0
           }
         }
         """, statistics_pb2.FeatureNameStatistics())
     self.assertCombinerOutputEqual(
         input_tables,
         basic_stats_generator.BasicStatsGenerator(),
         {'a': expected_a_stats})
예제 #15
0
 def test_basic_stats_generator_invalid_value_type(self):
     """Expects a TypeError when a feature's value list contains a dict."""
     stats_gen = basic_stats_generator.BasicStatsGenerator()
     bad_batches = [{'a': [{}]}]
     with self.assertRaises(TypeError):
         self.assertCombinerOutputEqual(bad_batches, stats_gen, None)
 def test_basic_stats_generator_handle_null_column(self):
     """Checks stats when a feature is a null-typed column in one table.

     Feature 'a' covers null coming before non-null: it is a null-typed
     column in the first table and list<int64> in the second.
     Feature 'b' covers null coming after non-null: it holds float lists in
     the first table and is a null-typed column in the second.
     """
     stats_gen = basic_stats_generator.BasicStatsGenerator(
         num_values_histogram_buckets=4,
         num_histogram_buckets=3,
         num_quantiles_histogram_buckets=4)

     feature_names = ['a', 'b']
     first_table = pa.Table.from_arrays([
         pa.array([None, None, None], type=pa.null()),
         pa.array([[1.0, 2.0, 3.0], [4.0], [5.0]]),
     ], feature_names)
     second_table = pa.Table.from_arrays([
         pa.array([[1, 2], None], type=pa.list_(pa.int64())),
         pa.array([None, None], type=pa.null()),
     ], feature_names)

     expected_a_stats = text_format.Parse(
             """
         path {
           step: "a"
         }
         num_stats {
           common_stats {
             num_non_missing: 1
             min_num_values: 2
             max_num_values: 2
             avg_num_values: 2.0
             num_values_histogram {
               buckets {
                 low_value: 2.0
                 high_value: 2.0
                 sample_count: 0.25
               }
               buckets {
                 low_value: 2.0
                 high_value: 2.0
                 sample_count: 0.25
               }
               buckets {
                 low_value: 2.0
                 high_value: 2.0
                 sample_count: 0.25
               }
               buckets {
                 low_value: 2.0
                 high_value: 2.0
                 sample_count: 0.25
               }
               type: QUANTILES
             }
             tot_num_values: 2
           }
           mean: 1.5
           std_dev: 0.5
           min: 1.0
           median: 2.0
           max: 2.0
           histograms {
             buckets {
               low_value: 1.0
               high_value: 1.3333333
               sample_count: 0.9955556
             }
             buckets {
               low_value: 1.3333333
               high_value: 1.6666667
               sample_count: 0.0022222
             }
             buckets {
               low_value: 1.6666667
               high_value: 2.0
               sample_count: 1.0022222
             }
           }
           histograms {
             buckets {
               low_value: 1.0
               high_value: 1.0
               sample_count: 0.5
             }
             buckets {
               low_value: 1.0
               high_value: 2.0
               sample_count: 0.5
             }
             buckets {
               low_value: 2.0
               high_value: 2.0
               sample_count: 0.5
             }
             buckets {
               low_value: 2.0
               high_value: 2.0
               sample_count: 0.5
             }
             type: QUANTILES
           }
         }
         """, statistics_pb2.FeatureNameStatistics())

     expected_b_stats = text_format.Parse(
             """
         path {
           step: 'b'
         }
         type: FLOAT
         num_stats {
           common_stats {
             num_non_missing: 3
             min_num_values: 1
             max_num_values: 3
             avg_num_values: 1.66666698456
             num_values_histogram {
               buckets {
                 low_value: 1.0
                 high_value: 1.0
                 sample_count: 0.75
               }
               buckets {
                 low_value: 1.0
                 high_value: 1.0
                 sample_count: 0.75
               }
               buckets {
                 low_value: 1.0
                 high_value: 3.0
                 sample_count: 0.75
               }
               buckets {
                 low_value: 3.0
                 high_value: 3.0
                 sample_count: 0.75
               }
               type: QUANTILES
             }
             tot_num_values: 5
           }
           mean: 3.0
           std_dev: 1.4142136
           min: 1.0
           median: 3.0
           max: 5.0
           histograms {
             buckets {
               low_value: 1.0
               high_value: 2.3333333
               sample_count: 1.9888889
             }
             buckets {
               low_value: 2.3333333
               high_value: 3.6666667
               sample_count: 1.0055556
             }
             buckets {
               low_value: 3.6666667
               high_value: 5.0
               sample_count: 2.0055556
             }
           }
           histograms {
             buckets {
               low_value: 1.0
               high_value: 2.0
               sample_count: 1.25
             }
             buckets {
               low_value: 2.0
               high_value: 3.0
               sample_count: 1.25
             }
             buckets {
               low_value: 3.0
               high_value: 4.0
               sample_count: 1.25
             }
             buckets {
               low_value: 4.0
               high_value: 5.0
               sample_count: 1.25
             }
             type: QUANTILES
           }
         }
         """, statistics_pb2.FeatureNameStatistics())

     expected_by_path = {
         types.FeaturePath(['a']): expected_a_stats,
         types.FeaturePath(['b']): expected_b_stats,
     }
     self.assertCombinerOutputEqual(
         [first_table, second_table], stats_gen, expected_by_path)
 def test_basic_stats_generator_invalid_value_numpy_dtype(self):
   """Expects a TypeError when feature 'a' holds complex-valued data."""
   stats_gen = basic_stats_generator.BasicStatsGenerator()
   complex_values = np.array([np.array([1+2j])])
   bad_batches = [{'a': complex_values}]
   with self.assertRaises(TypeError):
     self.assertCombinerOutputEqual(bad_batches, stats_gen, None)
예제 #18
0
    def test_tfdv_telemetry(self):
        """Runs BasicStatsGenerator in a Beam pipeline and checks its counters.

        Four examples with a float feature 'a', a string feature 'b' and an
        int feature 'c' (some of them None) are combined globally; the
        counters reported under `constants.METRICS_NAMESPACE` must match
        `expected_result` exactly, both in number and in committed value.
        """
        # Concrete dtypes replace aliases that newer NumPy no longer accepts:
        # np.object was removed in 1.24 (it was the builtin `object`), np.NaN
        # was removed in 2.0 (same value as np.nan), and passing the abstract
        # np.floating / np.integer as a dtype is deprecated (they resolved to
        # float64 and np.int_ respectively).
        examples = [{
            'a': np.array([1.0, 2.0], dtype=np.float64),
            'b': np.array(['a', 'b', 'c', 'e'], dtype=object),
            'c': None
        }, {
            'a': np.array([3.0, 4.0, np.nan, 5.0], dtype=np.float64),
            'b': np.array(['d', 'e', 'f'], dtype=object),
            'c': None
        }, {
            'a': None,
            'b': np.array(['a', 'b', 'c'], dtype=object),
            'c': np.array([10, 20, 30], dtype=np.int_)
        }, {
            'a': np.array([5.0], dtype=np.float64),
            'b': np.array(['d', 'e', 'f'], dtype=object),
            'c': np.array([1], dtype=np.int_)
        }]

        p = beam.Pipeline()
        _ = (p
             | 'CreateBatches' >> beam.Create(examples)
             | 'BasicStatsCombiner' >> beam.CombineGlobally(
                 stats_impl._BatchedCombineFnWrapper(
                     basic_stats_generator.BasicStatsGenerator())))

        runner = p.run()
        runner.wait_until_finish()
        result_metrics = runner.metrics()

        expected_result = {
            'num_instances': 4,
            'num_missing_feature_values': 3,
            'num_int_feature_values': 2,
            'int_feature_values_min_count': 1,
            'int_feature_values_max_count': 3,
            'int_feature_values_mean_count': 2,
            'num_float_feature_values': 3,
            'float_feature_values_min_count': 1,
            'float_feature_values_max_count': 4,
            'float_feature_values_mean_count': 2,
            'num_string_feature_values': 4,
            'string_feature_values_min_count': 3,
            'string_feature_values_max_count': 4,
            'string_feature_values_mean_count': 3,
        }
        # Check number of counters. The expected count is derived from
        # expected_result so it cannot drift from the table above.
        actual_metrics = result_metrics.query(
            beam.metrics.metric.MetricsFilter().with_namespace(
                constants.METRICS_NAMESPACE))['counters']
        self.assertLen(actual_metrics, len(expected_result))

        # Check each counter.
        for counter_name, expected_value in expected_result.items():
            actual_counter = result_metrics.query(
                beam.metrics.metric.MetricsFilter().with_name(
                    counter_name))['counters']
            self.assertLen(actual_counter, 1)
            self.assertEqual(actual_counter[0].committed, expected_value)
 def test_basic_stats_generator_column_not_list(self):
     """Expects a TypeError when column 'a' is a flat (non-list) array."""
     batches = [pa.Table.from_arrays([pa.array([1, 2, 3])], ['a'])]
     generator = basic_stats_generator.BasicStatsGenerator()
     # assertRaisesRegexp was a deprecated alias and is gone in Python 3.12;
     # assertRaisesRegex is the supported spelling with identical semantics.
     with self.assertRaisesRegex(TypeError,
                                 'Expected feature column to be a List'):
         self.assertCombinerOutputEqual(batches, generator, None)
 def test_basic_stats_generator_empty_list(self):
   """Checks that no input batches yield an empty result mapping."""
   stats_gen = basic_stats_generator.BasicStatsGenerator()
   no_batches = []
   no_stats = {}
   self.assertCombinerOutputEqual(no_batches, stats_gen, no_stats)
  def test_basic_stats_generator_with_individual_feature_value_missing(self):
    """Checks stats for a float feature whose value lists contain NaNs.

    Both expected histograms (STANDARD and QUANTILES) report `num_nan: 2`.
    """
    # input with two batches: first batch has two examples and second batch
    # has a single example.
    # The first batch's rows differ in length; NumPy >= 1.24 refuses to build
    # such a ragged array implicitly, so an object array is requested
    # explicitly (older releases produced the same object array anyway).
    # np.nan replaces np.NaN, which was removed in NumPy 2.0.
    batches = [{'a': np.array([np.array([1.0, 2.0]),
                               np.array([3.0, 4.0, np.nan, 5.0])],
                              dtype=object)},
               {'a': np.array([np.array([np.nan, 1.0])])}]

    expected_result = {
        'a': text_format.Parse(
            """
            name: 'a'
            type: FLOAT
            num_stats {
              common_stats {
                num_non_missing: 3
                num_missing: 0
                min_num_values: 2
                max_num_values: 4
                avg_num_values: 2.66666666
                tot_num_values: 8
                num_values_histogram {
                  buckets {
                    low_value: 2.0
                    high_value: 2.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 2.0
                    high_value: 4.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 4.0
                    high_value: 4.0
                    sample_count: 1.0
                  }
                  type: QUANTILES
                }
              }
              mean: 2.66666666
              std_dev: 1.49071198
              num_zeros: 0
              min: 1.0
              max: 5.0
              median: 3.0
              histograms {
                num_nan: 2
                buckets {
                  low_value: 1.0
                  high_value: 2.3333333
                  sample_count: 2.9866667
                }
                buckets {
                  low_value: 2.3333333
                  high_value: 3.6666667
                  sample_count: 1.0066667
                }
                buckets {
                  low_value: 3.6666667
                  high_value: 5.0
                  sample_count: 2.0066667
                }
                type: STANDARD
              }
              histograms {
                num_nan: 2
                buckets {
                  low_value: 1.0
                  high_value: 1.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 1.0
                  high_value: 3.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 3.0
                  high_value: 4.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 4.0
                  high_value: 5.0
                  sample_count: 1.5
                }
                type: QUANTILES
              }
            }
            """, statistics_pb2.FeatureNameStatistics())}
    generator = basic_stats_generator.BasicStatsGenerator(
        num_values_histogram_buckets=3, num_histogram_buckets=3,
        num_quantiles_histogram_buckets=4)
    self.assertCombinerOutputEqual(batches, generator, expected_result)
 def test_basic_stats_generator_with_multiple_features(self):
   """Checks stats for a float ('a'), string ('b') and int ('c') feature."""
   # input with two batches: first batch has two examples and second batch
   # has a single example.
   # In the first batch the rows of 'a' and 'b' differ in length; NumPy >= 1.24
   # refuses to build such ragged arrays implicitly, so object arrays are
   # requested explicitly (older releases produced the same object arrays
   # anyway). 'c' has equal-length rows and is left as it was.
   batches = [{'a': np.array([np.array([1.0, 2.0]),
                              np.array([3.0, 4.0, 5.0])], dtype=object),
               'b': np.array([np.array(['x', 'y', 'z', 'w']),
                              np.array(['qwe', 'abc'])], dtype=object),
               'c': np.array([np.linspace(1, 1000, 1000, dtype=np.int32),
                              np.linspace(1001, 2000, 1000, dtype=np.int32)])},
              {'a': np.array([np.array([1.0])]),
               'b': np.array([np.array(['ab'])]),
               'c': np.array([np.linspace(2001, 3000, 1000, dtype=np.int32)])}]
   expected_result = {
       'a': text_format.Parse(
           """
           name: 'a'
           type: FLOAT
           num_stats {
             common_stats {
               num_non_missing: 3
               num_missing: 0
               min_num_values: 1
               max_num_values: 3
               avg_num_values: 2.0
               tot_num_values: 6
               num_values_histogram {
                 buckets {
                   low_value: 1.0
                   high_value: 2.0
                   sample_count: 1.0
                 }
                 buckets {
                   low_value: 2.0
                   high_value: 3.0
                   sample_count: 1.0
                 }
                 buckets {
                   low_value: 3.0
                   high_value: 3.0
                   sample_count: 1.0
                 }
                 type: QUANTILES
               }
             }
             mean: 2.66666666
             std_dev: 1.49071198
             num_zeros: 0
             min: 1.0
             max: 5.0
             median: 3.0
             histograms {
               buckets {
                 low_value: 1.0
                 high_value: 2.3333333
                 sample_count: 2.9866667
               }
               buckets {
                 low_value: 2.3333333
                 high_value: 3.6666667
                 sample_count: 1.0066667
               }
               buckets {
                 low_value: 3.6666667
                 high_value: 5.0
                 sample_count: 2.0066667
               }
               type: STANDARD
             }
             histograms {
               buckets {
                 low_value: 1.0
                 high_value: 1.0
                 sample_count: 1.5
               }
               buckets {
                 low_value: 1.0
                 high_value: 3.0
                 sample_count: 1.5
               }
               buckets {
                 low_value: 3.0
                 high_value: 4.0
                 sample_count: 1.5
               }
               buckets {
                 low_value: 4.0
                 high_value: 5.0
                 sample_count: 1.5
               }
               type: QUANTILES
             }
           }
           """, statistics_pb2.FeatureNameStatistics()),
       'b': text_format.Parse(
           """
           name: 'b'
           type: STRING
           string_stats {
             common_stats {
               num_non_missing: 3
               num_missing: 0
               min_num_values: 1
               max_num_values: 4
               avg_num_values: 2.33333333
               tot_num_values: 7
               num_values_histogram {
                 buckets {
                   low_value: 1.0
                   high_value: 2.0
                   sample_count: 1.0
                 }
                 buckets {
                   low_value: 2.0
                   high_value: 4.0
                   sample_count: 1.0
                 }
                 buckets {
                   low_value: 4.0
                   high_value: 4.0
                   sample_count: 1.0
                 }
                 type: QUANTILES
               }
             }
             avg_length: 1.71428571
           }
           """, statistics_pb2.FeatureNameStatistics()),
       'c': text_format.Parse(
           """
           name: 'c'
           type: INT
           num_stats {
             common_stats {
               num_non_missing: 3
               num_missing: 0
               min_num_values: 1000
               max_num_values: 1000
               avg_num_values: 1000.0
               tot_num_values: 3000
               num_values_histogram {
                 buckets {
                   low_value: 1000.0
                   high_value: 1000.0
                   sample_count: 1.0
                 }
                 buckets {
                   low_value: 1000.0
                   high_value: 1000.0
                   sample_count: 1.0
                 }
                 buckets {
                   low_value: 1000.0
                   high_value: 1000.0
                   sample_count: 1.0
                 }
                 type: QUANTILES
               }
             }
             mean: 1500.5
             std_dev: 866.025355672
             min: 1.0
             max: 3000.0
             median: 1501.0
             histograms {
               buckets {
                 low_value: 1.0
                 high_value: 1000.66666667
                 sample_count: 999.666666667
               }
               buckets {
                 low_value: 1000.66666667
                 high_value: 2000.33333333
                 sample_count: 999.666666667
               }
               buckets {
                 low_value: 2000.33333333
                 high_value: 3000.0
                 sample_count: 1000.66666667
               }
               type: STANDARD
             }
             histograms {
               buckets {
                 low_value: 1.0
                 high_value: 751.0
                 sample_count: 750.0
               }
               buckets {
                 low_value: 751.0
                 high_value: 1501.0
                 sample_count: 750.0
               }
               buckets {
                 low_value: 1501.0
                 high_value: 2251.0
                 sample_count: 750.0
               }
               buckets {
                 low_value: 2251.0
                 high_value: 3000.0
                 sample_count: 750.0
               }
               type: QUANTILES
             }
           }
           """, statistics_pb2.FeatureNameStatistics())}
   generator = basic_stats_generator.BasicStatsGenerator(
       num_values_histogram_buckets=3, num_histogram_buckets=3,
       num_quantiles_histogram_buckets=4, epsilon=0.001)
   self.assertCombinerOutputEqual(batches, generator, expected_result)
 def test_basic_stats_generator_with_entire_feature_value_list_missing(self):
   # Verifies the combined statistics when some examples carry no value list
   # at all (None) for a feature; the expected protos below report those
   # examples through num_missing.
   # Input with two batches: first batch has three examples and second batch
   # has two examples.
   # The builtin `object` is used as dtype instead of the `np.object` alias,
   # which was deprecated in NumPy 1.20 and removed in NumPy 1.24.
   batches = [{'a': np.array([np.array([1.0, 2.0]), None,
                              np.array([3.0, 4.0, 5.0])], dtype=object),
               'b': np.array([np.array(['x', 'y', 'z', 'w']), None,
                              np.array(['qwe', 'abc'])], dtype=object)},
              {'a': np.array([np.array([1.0]), None], dtype=object),
               'b': np.array([None, np.array(['qwe'])], dtype=object)}]
   expected_result = {
       'a': text_format.Parse(
           """
           name: 'a'
           type: FLOAT
           num_stats {
             common_stats {
               num_non_missing: 3
               num_missing: 2
               min_num_values: 1
               max_num_values: 3
               avg_num_values: 2.0
               tot_num_values: 6
               num_values_histogram {
                 buckets {
                   low_value: 1.0
                   high_value: 2.0
                   sample_count: 1.0
                 }
                 buckets {
                   low_value: 2.0
                   high_value: 3.0
                   sample_count: 1.0
                 }
                 buckets {
                   low_value: 3.0
                   high_value: 3.0
                   sample_count: 1.0
                 }
                 type: QUANTILES
               }
             }
             mean: 2.66666666
             std_dev: 1.49071198
             num_zeros: 0
             min: 1.0
             max: 5.0
             median: 3.0
             histograms {
               buckets {
                 low_value: 1.0
                 high_value: 2.3333333
                 sample_count: 2.9866667
               }
               buckets {
                 low_value: 2.3333333
                 high_value: 3.6666667
                 sample_count: 1.0066667
               }
               buckets {
                 low_value: 3.6666667
                 high_value: 5.0
                 sample_count: 2.0066667
               }
               type: STANDARD
             }
             histograms {
               buckets {
                 low_value: 1.0
                 high_value: 1.0
                 sample_count: 1.5
               }
               buckets {
                 low_value: 1.0
                 high_value: 3.0
                 sample_count: 1.5
               }
               buckets {
                 low_value: 3.0
                 high_value: 4.0
                 sample_count: 1.5
               }
               buckets {
                 low_value: 4.0
                 high_value: 5.0
                 sample_count: 1.5
               }
               type: QUANTILES
             }
           }
           """, statistics_pb2.FeatureNameStatistics()),
       'b': text_format.Parse(
           """
           name: 'b'
           type: STRING
           string_stats {
             common_stats {
               num_non_missing: 3
               num_missing: 2
               min_num_values: 1
               max_num_values: 4
               avg_num_values: 2.33333333
               tot_num_values: 7
               num_values_histogram {
                 buckets {
                   low_value: 1.0
                   high_value: 2.0
                   sample_count: 1.0
                 }
                 buckets {
                   low_value: 2.0
                   high_value: 4.0
                   sample_count: 1.0
                 }
                 buckets {
                   low_value: 4.0
                   high_value: 4.0
                   sample_count: 1.0
                 }
                 type: QUANTILES
               }
             }
             avg_length: 1.85714285
           }
           """, statistics_pb2.FeatureNameStatistics())}
   generator = basic_stats_generator.BasicStatsGenerator(
       num_values_histogram_buckets=3, num_histogram_buckets=3,
       num_quantiles_histogram_buckets=4)
   self.assertCombinerOutputEqual(batches, generator, expected_result)
 def test_basic_stats_generator_with_weight_feature(self):
   # Verifies both the unweighted and the weighted statistics when examples
   # are weighted by feature 'w'.
   # Input with two batches of two examples each; the second example of the
   # second batch has no value list (None) for 'a' and 'b'.
   # The ragged / None-containing inputs are created with dtype=object
   # explicitly: NumPy >= 1.24 raises ValueError for such nested sequences
   # otherwise. The equal-length 'w' arrays are left as they were.
   batches = [{'a': np.array([np.array([1.0, 2.0]),
                              np.array([3.0, 4.0, 5.0])], dtype=object),
               'b': np.array([np.array([1, 2]),
                              np.array([3, 4, 5])], dtype=object),
               'w': np.array([np.array([1.0]), np.array([2.0])])},
              {'a': np.array([np.array([1.0,]), None], dtype=object),
               'b': np.array([np.array([1]), None], dtype=object),
               'w': np.array([np.array([3.0]), np.array([2.0])])}]
   expected_result = {
       'a': text_format.Parse(
           """
           name: 'a'
           type: FLOAT
           num_stats {
             common_stats {
               num_non_missing: 3
               num_missing: 1
               min_num_values: 1
               max_num_values: 3
               avg_num_values: 2.0
               tot_num_values: 6
               num_values_histogram {
                 buckets {
                   low_value: 1.0
                   high_value: 1.0
                   sample_count: 0.75
                 }
                 buckets {
                   low_value: 1.0
                   high_value: 2.0
                   sample_count: 0.75
                 }
                 buckets {
                   low_value: 2.0
                   high_value: 3.0
                   sample_count: 0.75
                 }
                 buckets {
                   low_value: 3.0
                   high_value: 3.0
                   sample_count: 0.75
                 }
                 type: QUANTILES
               }
               weighted_common_stats {
                 num_non_missing: 6.0
                 num_missing: 2.0
                 avg_num_values: 1.83333333
                 tot_num_values: 11.0
               }
             }
             mean: 2.66666666
             std_dev: 1.49071198
             num_zeros: 0
             min: 1.0
             max: 5.0
             median: 3.0
             histograms {
               buckets {
                 low_value: 1.0
                 high_value: 2.3333333
                 sample_count: 2.9866667
               }
               buckets {
                 low_value: 2.3333333
                 high_value: 3.6666667
                 sample_count: 1.0066667
               }
               buckets {
                 low_value: 3.6666667
                 high_value: 5.0
                 sample_count: 2.0066667
               }
               type: STANDARD
             }
             histograms {
               buckets {
                 low_value: 1.0
                 high_value: 1.0
                 sample_count: 1.5
               }
               buckets {
                 low_value: 1.0
                 high_value: 3.0
                 sample_count: 1.5
               }
               buckets {
                 low_value: 3.0
                 high_value: 4.0
                 sample_count: 1.5
               }
               buckets {
                 low_value: 4.0
                 high_value: 5.0
                 sample_count: 1.5
               }
               type: QUANTILES
             }
             weighted_numeric_stats {
               mean: 2.7272727
               std_dev: 1.5427784
               median: 3.0
               histograms {
                 buckets {
                   low_value: 1.0
                   high_value: 2.3333333
                   sample_count: 4.9988889
                 }
                 buckets {
                   low_value: 2.3333333
                   high_value: 3.6666667
                   sample_count: 1.9922222
                 }
                 buckets {
                   low_value: 3.6666667
                   high_value: 5.0
                   sample_count: 4.0088889
                 }
               }
               histograms {
                 buckets {
                   low_value: 1.0
                   high_value: 1.0
                   sample_count: 2.75
                 }
                 buckets {
                   low_value: 1.0
                   high_value: 3.0
                   sample_count: 2.75
                 }
                 buckets {
                   low_value: 3.0
                   high_value: 4.0
                   sample_count: 2.75
                 }
                 buckets {
                   low_value: 4.0
                   high_value: 5.0
                   sample_count: 2.75
                 }
                 type: QUANTILES
               }
             }
           }
           """, statistics_pb2.FeatureNameStatistics()),
       'b': text_format.Parse(
           """
           name: 'b'
           type: INT
           num_stats {
             common_stats {
               num_non_missing: 3
               num_missing: 1
               min_num_values: 1
               max_num_values: 3
               avg_num_values: 2.0
               tot_num_values: 6
               num_values_histogram {
                 buckets {
                   low_value: 1.0
                   high_value: 1.0
                   sample_count: 0.75
                 }
                 buckets {
                   low_value: 1.0
                   high_value: 2.0
                   sample_count: 0.75
                 }
                 buckets {
                   low_value: 2.0
                   high_value: 3.0
                   sample_count: 0.75
                 }
                 buckets {
                   low_value: 3.0
                   high_value: 3.0
                   sample_count: 0.75
                 }
                 type: QUANTILES
               }
               weighted_common_stats {
                 num_non_missing: 6.0
                 num_missing: 2.0
                 avg_num_values: 1.83333333
                 tot_num_values: 11.0
               }
             }
             mean: 2.66666666
             std_dev: 1.49071198
             num_zeros: 0
             min: 1.0
             max: 5.0
             median: 3.0
             histograms {
               buckets {
                 low_value: 1.0
                 high_value: 2.3333333
                 sample_count: 2.9866667
               }
               buckets {
                 low_value: 2.3333333
                 high_value: 3.6666667
                 sample_count: 1.0066667
               }
               buckets {
                 low_value: 3.6666667
                 high_value: 5.0
                 sample_count: 2.0066667
               }
               type: STANDARD
             }
             histograms {
               buckets {
                 low_value: 1.0
                 high_value: 1.0
                 sample_count: 1.5
               }
               buckets {
                 low_value: 1.0
                 high_value: 3.0
                 sample_count: 1.5
               }
               buckets {
                 low_value: 3.0
                 high_value: 4.0
                 sample_count: 1.5
               }
               buckets {
                 low_value: 4.0
                 high_value: 5.0
                 sample_count: 1.5
               }
               type: QUANTILES
             }
             weighted_numeric_stats {
               mean: 2.7272727
               std_dev: 1.5427784
               median: 3.0
               histograms {
                 buckets {
                   low_value: 1.0
                   high_value: 2.3333333
                   sample_count: 4.9988889
                 }
                 buckets {
                   low_value: 2.3333333
                   high_value: 3.6666667
                   sample_count: 1.9922222
                 }
                 buckets {
                   low_value: 3.6666667
                   high_value: 5.0
                   sample_count: 4.0088889
                 }
               }
               histograms {
                 buckets {
                   low_value: 1.0
                   high_value: 1.0
                   sample_count: 2.75
                 }
                 buckets {
                   low_value: 1.0
                   high_value: 3.0
                   sample_count: 2.75
                 }
                 buckets {
                   low_value: 3.0
                   high_value: 4.0
                   sample_count: 2.75
                 }
                 buckets {
                   low_value: 4.0
                   high_value: 5.0
                   sample_count: 2.75
                 }
                 type: QUANTILES
               }
             }
           }
           """, statistics_pb2.FeatureNameStatistics())}
   generator = basic_stats_generator.BasicStatsGenerator(
       weight_feature='w',
       num_values_histogram_buckets=4, num_histogram_buckets=3,
       num_quantiles_histogram_buckets=4)
   self.assertCombinerOutputEqual(batches, generator, expected_result)
    def test_basic_stats_generator_with_multiple_features(self):
        # Runs the generator over two Arrow tables that each hold a float
        # feature 'a', a bytes feature 'b' and an int32 feature 'c', and
        # compares the combined output with the hand-written protos below.
        feature_names = ['a', 'b', 'c']
        # First table: two examples per feature.
        first_table = pa.Table.from_arrays(
            [
                pa.array([[1.0, 2.0], [3.0, 4.0, 5.0]]),
                pa.array([[b'x', b'y', b'z', b'w'], [b'qwe', b'abc']]),
                pa.array([
                    np.linspace(1, 1000, 1000, dtype=np.int32),
                    np.linspace(1001, 2000, 1000, dtype=np.int32),
                ]),
            ],
            feature_names)
        # Second table: a single example per feature.
        second_table = pa.Table.from_arrays(
            [
                pa.array([[1.0]]),
                pa.array([[b'ab']]),
                pa.array([np.linspace(2001, 3000, 1000, dtype=np.int32)]),
            ],
            feature_names)

        # Expected statistics for the float feature 'a'.
        expected_a = text_format.Parse(
            """
            path {
              step: 'a'
            }
            type: FLOAT
            num_stats {
              common_stats {
                num_non_missing: 3
                min_num_values: 1
                max_num_values: 3
                avg_num_values: 2.0
                tot_num_values: 6
                num_values_histogram {
                  buckets {
                    low_value: 1.0
                    high_value: 2.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 2.0
                    high_value: 3.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 3.0
                    high_value: 3.0
                    sample_count: 1.0
                  }
                  type: QUANTILES
                }
              }
              mean: 2.66666666
              std_dev: 1.49071198
              num_zeros: 0
              min: 1.0
              max: 5.0
              median: 3.0
              histograms {
                buckets {
                  low_value: 1.0
                  high_value: 2.3333333
                  sample_count: 2.9866667
                }
                buckets {
                  low_value: 2.3333333
                  high_value: 3.6666667
                  sample_count: 1.0066667
                }
                buckets {
                  low_value: 3.6666667
                  high_value: 5.0
                  sample_count: 2.0066667
                }
                type: STANDARD
              }
              histograms {
                buckets {
                  low_value: 1.0
                  high_value: 1.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 1.0
                  high_value: 3.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 3.0
                  high_value: 4.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 4.0
                  high_value: 5.0
                  sample_count: 1.5
                }
                type: QUANTILES
              }
            }
            """, statistics_pb2.FeatureNameStatistics())
        # Expected statistics for the bytes feature 'b'.
        expected_b = text_format.Parse(
            """
            path {
              step: 'b'
            }
            type: STRING
            string_stats {
              common_stats {
                num_non_missing: 3
                min_num_values: 1
                max_num_values: 4
                avg_num_values: 2.33333333
                tot_num_values: 7
                num_values_histogram {
                  buckets {
                    low_value: 1.0
                    high_value: 2.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 2.0
                    high_value: 4.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 4.0
                    high_value: 4.0
                    sample_count: 1.0
                  }
                  type: QUANTILES
                }
              }
              avg_length: 1.71428571
            }
            """, statistics_pb2.FeatureNameStatistics())
        # Expected statistics for the int feature 'c'.
        expected_c = text_format.Parse(
            """
            path {
              step: 'c'
            }
            type: INT
            num_stats {
              common_stats {
                num_non_missing: 3
                min_num_values: 1000
                max_num_values: 1000
                avg_num_values: 1000.0
                tot_num_values: 3000
                num_values_histogram {
                  buckets {
                    low_value: 1000.0
                    high_value: 1000.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 1000.0
                    high_value: 1000.0
                    sample_count: 1.0
                  }
                  buckets {
                    low_value: 1000.0
                    high_value: 1000.0
                    sample_count: 1.0
                  }
                  type: QUANTILES
                }
              }
              mean: 1500.5
              std_dev: 866.025355672
              min: 1.0
              max: 3000.0
              median: 1501.0
              histograms {
                buckets {
                  low_value: 1.0
                  high_value: 1000.66666667
                  sample_count: 999.666666667
                }
                buckets {
                  low_value: 1000.66666667
                  high_value: 2000.33333333
                  sample_count: 999.666666667
                }
                buckets {
                  low_value: 2000.33333333
                  high_value: 3000.0
                  sample_count: 1000.66666667
                }
                type: STANDARD
              }
              histograms {
                buckets {
                  low_value: 1.0
                  high_value: 751.0
                  sample_count: 750.0
                }
                buckets {
                  low_value: 751.0
                  high_value: 1501.0
                  sample_count: 750.0
                }
                buckets {
                  low_value: 1501.0
                  high_value: 2251.0
                  sample_count: 750.0
                }
                buckets {
                  low_value: 2251.0
                  high_value: 3000.0
                  sample_count: 750.0
                }
                type: QUANTILES
              }
            }
            """, statistics_pb2.FeatureNameStatistics())

        expected_result = {
            types.FeaturePath(['a']): expected_a,
            types.FeaturePath(['b']): expected_b,
            types.FeaturePath(['c']): expected_c,
        }
        generator = basic_stats_generator.BasicStatsGenerator(
            num_values_histogram_buckets=3, num_histogram_buckets=3,
            num_quantiles_histogram_buckets=4, epsilon=0.001)
        self.assertCombinerOutputEqual(
            [first_table, second_table], generator, expected_result)
    def test_basic_stats_generator_with_weight_feature(self):
        # Verifies both the unweighted and the weighted statistics when
        # examples are weighted by feature 'w'.
        # Input with two batches of two examples each. In the second batch the
        # first example's 'a' list contains three NaNs and the second example
        # has no value list (null) for 'a' and 'b'.
        # np.nan is used instead of the np.NaN alias, which was removed in
        # NumPy 2.0.
        b1 = pa.Table.from_arrays([
            pa.array([[1.0, 2.0], [3.0, 4.0, 5.0]]),
            pa.array([[1, 2], [3, 4, 5]]),
            pa.array([[1.0], [2.0]])
        ], ['a', 'b', 'w'])
        b2 = pa.Table.from_arrays([
            pa.array([[1.0, np.nan, np.nan, np.nan], None]),
            pa.array([[1], None]),
            pa.array([[3.0], [2.0]])
        ], ['a', 'b', 'w'])

        batches = [b1, b2]
        expected_result = {
            types.FeaturePath(['a']):
            text_format.Parse(
                """
            path {
              step: 'a'
            }
            type: FLOAT
            num_stats {
              common_stats {
                num_non_missing: 3
                min_num_values: 2
                max_num_values: 4
                avg_num_values: 3.0
                tot_num_values: 9
                num_values_histogram {
                  buckets {
                    low_value: 2.0
                    high_value: 2.0
                    sample_count: 0.75
                  }
                  buckets {
                    low_value: 2.0
                    high_value: 3.0
                    sample_count: 0.75
                  }
                  buckets {
                    low_value: 3.0
                    high_value: 4.0
                    sample_count: 0.75
                  }
                  buckets {
                    low_value: 4.0
                    high_value: 4.0
                    sample_count: 0.75
                  }
                  type: QUANTILES
                }
                weighted_common_stats {
                  num_non_missing: 6.0
                  avg_num_values: 3.33333333
                  tot_num_values: 20.0
                }
              }
              mean: 2.66666666
              std_dev: 1.49071198
              num_zeros: 0
              min: 1.0
              max: 5.0
              median: 3.0
              histograms {
                num_nan: 3
                buckets {
                  low_value: 1.0
                  high_value: 2.3333333
                  sample_count: 2.9866667
                }
                buckets {
                  low_value: 2.3333333
                  high_value: 3.6666667
                  sample_count: 1.0066667
                }
                buckets {
                  low_value: 3.6666667
                  high_value: 5.0
                  sample_count: 2.0066667
                }
                type: STANDARD
              }
              histograms {
                num_nan: 3
                buckets {
                  low_value: 1.0
                  high_value: 1.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 1.0
                  high_value: 3.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 3.0
                  high_value: 4.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 4.0
                  high_value: 5.0
                  sample_count: 1.5
                }
                type: QUANTILES
              }
              weighted_numeric_stats {
                mean: 2.7272727
                std_dev: 1.5427784
                median: 3.0
                histograms {
                  num_nan: 3
                  buckets {
                    low_value: 1.0
                    high_value: 2.3333333
                    sample_count: 4.9988889
                  }
                  buckets {
                    low_value: 2.3333333
                    high_value: 3.6666667
                    sample_count: 1.9922222
                  }
                  buckets {
                    low_value: 3.6666667
                    high_value: 5.0
                    sample_count: 4.0088889
                  }
                }
                histograms {
                  num_nan: 3
                  buckets {
                    low_value: 1.0
                    high_value: 1.0
                    sample_count: 2.75
                  }
                  buckets {
                    low_value: 1.0
                    high_value: 3.0
                    sample_count: 2.75
                  }
                  buckets {
                    low_value: 3.0
                    high_value: 4.0
                    sample_count: 2.75
                  }
                  buckets {
                    low_value: 4.0
                    high_value: 5.0
                    sample_count: 2.75
                  }
                  type: QUANTILES
                }
              }
            }
            """, statistics_pb2.FeatureNameStatistics()),
            types.FeaturePath(['b']):
            text_format.Parse(
                """
            path {
              step: 'b'
            }
            type: INT
            num_stats {
              common_stats {
                num_non_missing: 3
                min_num_values: 1
                max_num_values: 3
                avg_num_values: 2.0
                tot_num_values: 6
                num_values_histogram {
                  buckets {
                    low_value: 1.0
                    high_value: 1.0
                    sample_count: 0.75
                  }
                  buckets {
                    low_value: 1.0
                    high_value: 2.0
                    sample_count: 0.75
                  }
                  buckets {
                    low_value: 2.0
                    high_value: 3.0
                    sample_count: 0.75
                  }
                  buckets {
                    low_value: 3.0
                    high_value: 3.0
                    sample_count: 0.75
                  }
                  type: QUANTILES
                }
                weighted_common_stats {
                  num_non_missing: 6.0
                  avg_num_values: 1.83333333
                  tot_num_values: 11.0
                }
              }
              mean: 2.66666666
              std_dev: 1.49071198
              num_zeros: 0
              min: 1.0
              max: 5.0
              median: 3.0
              histograms {
                buckets {
                  low_value: 1.0
                  high_value: 2.3333333
                  sample_count: 2.9866667
                }
                buckets {
                  low_value: 2.3333333
                  high_value: 3.6666667
                  sample_count: 1.0066667
                }
                buckets {
                  low_value: 3.6666667
                  high_value: 5.0
                  sample_count: 2.0066667
                }
                type: STANDARD
              }
              histograms {
                buckets {
                  low_value: 1.0
                  high_value: 1.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 1.0
                  high_value: 3.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 3.0
                  high_value: 4.0
                  sample_count: 1.5
                }
                buckets {
                  low_value: 4.0
                  high_value: 5.0
                  sample_count: 1.5
                }
                type: QUANTILES
              }
              weighted_numeric_stats {
                mean: 2.7272727
                std_dev: 1.5427784
                median: 3.0
                histograms {
                  buckets {
                    low_value: 1.0
                    high_value: 2.3333333
                    sample_count: 4.9988889
                  }
                  buckets {
                    low_value: 2.3333333
                    high_value: 3.6666667
                    sample_count: 1.9922222
                  }
                  buckets {
                    low_value: 3.6666667
                    high_value: 5.0
                    sample_count: 4.0088889
                  }
                }
                histograms {
                  buckets {
                    low_value: 1.0
                    high_value: 1.0
                    sample_count: 2.75
                  }
                  buckets {
                    low_value: 1.0
                    high_value: 3.0
                    sample_count: 2.75
                  }
                  buckets {
                    low_value: 3.0
                    high_value: 4.0
                    sample_count: 2.75
                  }
                  buckets {
                    low_value: 4.0
                    high_value: 5.0
                    sample_count: 2.75
                  }
                  type: QUANTILES
                }
              }
            }
            """, statistics_pb2.FeatureNameStatistics())
        }
        generator = basic_stats_generator.BasicStatsGenerator(
            weight_feature='w',
            num_values_histogram_buckets=4,
            num_histogram_buckets=3,
            num_quantiles_histogram_buckets=4)
        self.assertCombinerOutputEqual(batches, generator, expected_result)
 def test_basic_stats_generator_single_feature(self):
     # Feature 'a' is fed in two batches: the first carries two examples
     # ([1, 2] and [3, 4, 5]), the second carries one example ([1]).
     batches = [
         pa.Table.from_arrays([pa.array(examples)], ['a'])
         for examples in ([[1.0, 2.0], [3.0, 4.0, 5.0]], [[1.0]])
     ]
     feature_path = types.FeaturePath(['a'])
     # Text-format FeatureNameStatistics the combiner is expected to emit
     # for 'a' once both batches have been merged.
     expected_stats_pbtxt = """
         path {
           step: 'a'
         }
         type: FLOAT
         num_stats {
           common_stats {
             num_non_missing: 3
             min_num_values: 1
             max_num_values: 3
             avg_num_values: 2.0
             tot_num_values: 6
             num_values_histogram {
               buckets {
                 low_value: 1.0
                 high_value: 1.0
                 sample_count: 0.75
               }
               buckets {
                 low_value: 1.0
                 high_value: 2.0
                 sample_count: 0.75
               }
               buckets {
                 low_value: 2.0
                 high_value: 3.0
                 sample_count: 0.75
               }
               buckets {
                 low_value: 3.0
                 high_value: 3.0
                 sample_count: 0.75
               }
               type: QUANTILES
             }
           }
           mean: 2.66666666
           std_dev: 1.49071198
           num_zeros: 0
           min: 1.0
           max: 5.0
           median: 3.0
           histograms {
             buckets {
               low_value: 1.0
               high_value: 2.3333333
               sample_count: 2.9866667
             }
             buckets {
               low_value: 2.3333333
               high_value: 3.6666667
               sample_count: 1.0066667
             }
             buckets {
               low_value: 3.6666667
               high_value: 5.0
               sample_count: 2.0066667
             }
             type: STANDARD
           }
           histograms {
             buckets {
               low_value: 1.0
               high_value: 1.0
               sample_count: 1.5
             }
             buckets {
               low_value: 1.0
               high_value: 3.0
               sample_count: 1.5
             }
             buckets {
               low_value: 3.0
               high_value: 4.0
               sample_count: 1.5
             }
             buckets {
               low_value: 4.0
               high_value: 5.0
               sample_count: 1.5
             }
             type: QUANTILES
           }
         }
         """
     expected_result = {
         feature_path:
         text_format.Parse(expected_stats_pbtxt,
                           statistics_pb2.FeatureNameStatistics()),
     }
     # Bucket counts line up with the histograms spelled out above: 4 buckets
     # for the num-values and quantiles histograms, 3 for the standard one.
     bucket_options = {
         'num_values_histogram_buckets': 4,
         'num_histogram_buckets': 3,
         'num_quantiles_histogram_buckets': 4,
     }
     generator = basic_stats_generator.BasicStatsGenerator(**bucket_options)
     self.assertCombinerOutputEqual(batches, generator, expected_result)