def test_all_null_mask_all_null(self):
  """all_null_mask is True for every row when all queried columns are null."""
  # Build a two-column table where both columns are pure NullArrays.
  columns = [pa.array([None, None], type=pa.null()) for _ in range(2)]
  batch = input_batch.InputBatch(pa.Table.from_arrays(columns, ['f1', 'f2']))
  paths = [types.FeaturePath([name]) for name in ('f1', 'f2')]
  np.testing.assert_array_equal(
      batch.all_null_mask(*paths), np.array([True, True]))
def test_list_lengths_null_array(self):
  """A pure-null column reports a list length of 0 for every row."""
  null_column = pa.array([None] * 3, type=pa.null())
  batch = input_batch.InputBatch(
      pa.Table.from_arrays([null_column], ['f1']))
  lengths = batch.list_lengths(types.FeaturePath(['f1']))
  np.testing.assert_array_equal(lengths, [0, 0, 0])
def test_lift_null_y(self):
  """Lift generator emits no output when the y feature is entirely null."""
  examples = [
      pa.Table.from_arrays([
          pa.array([['a'], ['a'], ['b'], ['a']]),
          # 'string_y' is an all-null column, so no y values exist to lift.
          pa.array([None, None, None, None], type=pa.null()),
      ], ['categorical_x', 'string_y']),
  ]
  schema = text_format.Parse(
      """ feature { name: 'categorical_x' type: BYTES } feature { name: 'string_y' type: BYTES } """,
      schema_pb2.Schema())
  # With no y values, no cross-feature statistics should be produced.
  expected_result = []
  generator = lift_stats_generator.LiftStatsGenerator(
      schema=schema, y_path=types.FeaturePath(['string_y']))
  self.assertSlicingAwareTransformOutputEqual(
      examples,
      generator,
      expected_result,
      add_default_slice_key_to_input=True,
      add_default_slice_key_to_output=True)
def test_example_value_presence_null_array(self):
  """_get_example_value_presence yields None for an all-null column."""
  table = pa.Table.from_arrays(
      [pa.array([None] * 2, type=pa.null())], ['x'])
  presence = lift_stats_generator._get_example_value_presence(
      table, types.FeaturePath(['x']), boundaries=None)
  self.assertIsNone(presence)
def _GetEmptyTable(num_rows: int) -> pa.Table:
  """Returns a pa.Table with `num_rows` rows and zero columns.

  pyarrow offers no API to build a table with rows but no columns, so a
  throwaway all-null column is used to fix the row count and then dropped.
  """
  placeholder = pa.array([None] * num_rows, type=pa.null())
  with_dummy = pa.Table.from_arrays([placeholder], ["dummy"])
  return with_dummy.remove_column(0)
def test_null_mask_null_array(self):
  """null_mask marks every row of a NullArray column as null."""
  table = pa.Table.from_arrays(
      [pa.array([None], type=pa.null())], ['feature'])
  batch = input_batch.InputBatch(table)
  mask = batch.null_mask(types.FeaturePath(['feature']))
  np.testing.assert_array_equal(mask, np.array([True]))
def _process_column_infos(self, column_infos: List[csv_decoder.ColumnInfo]):
  """Derives a value handler and an arrow type for each CSV column.

  Populates self._column_handlers, self._column_arrow_types and
  self._column_names, one entry per element of `column_infos`.
  """
  handlers = []
  arrow_types = []
  for info in column_infos:
    if info.type == statistics_pb2.FeatureNameStatistics.INT:
      handler, arrow_type = (lambda v: (int(v),)), pa.list_(pa.int64())
    elif info.type == statistics_pb2.FeatureNameStatistics.FLOAT:
      handler, arrow_type = (lambda v: (float(v),)), pa.list_(pa.float32())
    elif info.type == statistics_pb2.FeatureNameStatistics.STRING:
      handler, arrow_type = (lambda v: (v,)), pa.list_(pa.binary())
    else:
      # Unrecognized column type: discard the value and use a null type.
      handler, arrow_type = (lambda _: None), pa.null()
    handlers.append(handler)
    arrow_types.append(arrow_type)
  self._column_handlers = handlers
  self._column_arrow_types = arrow_types
  self._column_names = [info.name for info in column_infos]
class BinArrayTest(parameterized.TestCase):
  """Tests for bin_array."""

  # Each case: (name, input array, bucket boundaries,
  #             expected element indices, expected bucket ids).
  @parameterized.named_parameters([
      ('simple', pa.array([0.1, 0.5, 0.75]), [0.25, 0.75],
       [0, 1, 2], [0, 1, 2]),
      ('negative_values', pa.array([-0.8, -0.5, -0.1]), [0.25],
       [0, 1, 2], [0, 0, 0]),
      ('inf_values', pa.array([float('-inf'), 0.5, float('inf')]),
       [0.25, 0.75], [0, 1, 2], [0, 1, 2]),
      # NaN elements are skipped: only index 1 appears in the output.
      ('nan_values', pa.array([np.nan, 0.5]), [0.25, 0.75], [1], [1]),
      ('negative_boundaries', pa.array([-0.8, -0.5]), [-0.75, -0.25],
       [0, 1], [0, 1]),
      ('empty_array', pa.array([]), [0.25], [], []),
      # None elements are skipped, like NaN.
      ('none_value', pa.array([None, 0.5]), [0.25], [1], [1]),
      # A NullArray produces no binned values at all.
      ('null_array', pa.array([None, None], type=pa.null()), [0.25], [], [])
  ])
  def test_bin_array(self, array, boundaries, expected_indices,
                     expected_bins):
    indices, bins = bin_util.bin_array(array, boundaries)
    np.testing.assert_array_equal(expected_indices, indices)
    np.testing.assert_array_equal(expected_bins, bins)
def test_lift_slice_aware(self):
  """Lift statistics are aggregated per slice key across batches.

  Two batches carry 'slice1' and two carry 'slice2'; lift for a slice is
  computed over all batches with that key. The second 'slice2' batch has an
  all-null categorical_x, so it contributes only y counts to that slice.
  """
  examples = [
      ('slice1', pa.Table.from_arrays([
          pa.array([['a'], ['a'], ['b'], ['a']]),
          pa.array([['cat'], ['dog'], ['cat'], ['dog']]),
      ], ['categorical_x', 'string_y'])),
      ('slice2', pa.Table.from_arrays([
          pa.array([['a'], ['a'], ['a']]),
          pa.array([['cat'], ['dog'], ['dog']]),
      ], ['categorical_x', 'string_y'])),
      ('slice1', pa.Table.from_arrays([
          pa.array([['a'], ['a'], ['b'], ['a']]),
          pa.array([['cat'], ['dog'], ['cat'], ['dog']]),
      ], ['categorical_x', 'string_y'])),
      ('slice2', pa.Table.from_arrays([
          pa.array([None, None, None, None], type=pa.null()),
          pa.array([['cat'], ['dog'], ['cat'], ['dog']]),
      ], ['categorical_x', 'string_y'])),
  ]
  schema = text_format.Parse(
      """ feature { name: 'categorical_x' type: BYTES } feature { name: 'string_y' type: BYTES } """,
      schema_pb2.Schema())
  expected_result = [
      ('slice1',
       text_format.Parse(
           """ cross_features { path_x { step: "categorical_x" } path_y { step: "string_y" } categorical_cross_stats { lift_series { y_string: "cat" y_count: 4 lift_values { x_string: "b" lift: 2.0 x_count: 2 x_and_y_count: 2 } lift_values { x_string: "a" lift: 0.666666984558 x_count: 6 x_and_y_count: 2 } } lift_series { y_string: "dog" y_count: 4 lift_values { x_string: "a" lift: 1.33333301544 x_count: 6 x_and_y_count: 4 } lift_values { x_string: "b" lift: 0.0 x_count: 2 x_and_y_count: 0 } } } }""",
           statistics_pb2.DatasetFeatureStatistics())),
      ('slice2',
       text_format.Parse(
           """ cross_features { path_x { step: "categorical_x" } path_y { step: "string_y" } categorical_cross_stats { lift_series { y_string: "cat" y_count: 3 lift_values { x_string: "a" lift: 0.777778029441 x_count: 3 x_and_y_count: 1 } } lift_series { y_string: "dog" y_count: 4 lift_values { x_string: "a" lift: 1.16666698455 x_count: 3 x_and_y_count: 2 } } } }""",
           statistics_pb2.DatasetFeatureStatistics())),
  ]
  generator = lift_stats_generator.LiftStatsGenerator(
      schema=schema, y_path=types.FeaturePath(['string_y']))
  self.assertSlicingAwareTransformOutputEqual(examples, generator,
                                              expected_result)
def test_topk_uniques_with_categorical_feature(self):
  """Top-k/uniques treats a schema-declared categorical INT as strings.

  The second batch is an all-null column for 'fa' and must not affect the
  counts derived from the first batch.
  """
  examples = [
      pa.Table.from_arrays(
          [pa.array([[12, 23, 34, 12], [45, 23], [12, 12, 34, 45]])],
          ['fa']),
      pa.Table.from_arrays([pa.array([None, None], type=pa.null())], ['fa'])
  ]
  expected_result = [
      text_format.Parse(
          """ features { path { step: 'fa' } type: INT string_stats { top_values { value: '12' frequency: 4 } top_values { value: '45' frequency: 2 } rank_histogram { buckets { low_rank: 0 high_rank: 0 label: "12" sample_count: 4.0 } buckets { low_rank: 1 high_rank: 1 label: "45" sample_count: 2.0 } buckets { low_rank: 2 high_rank: 2 label: "34" sample_count: 2.0 } } } }""",
          statistics_pb2.DatasetFeatureStatistics()),
      text_format.Parse(
          """ features { path { step: 'fa' } type: INT string_stats { unique: 4 } }""",
          statistics_pb2.DatasetFeatureStatistics()),
  ]
  schema = text_format.Parse(
      """ feature { name: "fa" type: INT int_domain { is_categorical: true } } """,
      schema_pb2.Schema())
  generator = top_k_uniques_stats_generator.TopKUniquesStatsGenerator(
      schema=schema, num_top_values=2, num_rank_histogram_buckets=3)
  self.assertSlicingAwareTransformOutputEqual(
      examples,
      generator,
      expected_result,
      add_default_slice_key_to_input=True,
      add_default_slice_key_to_output=True)
pa.array([None, [1., 2., 3.], None, None], pa.list_(pa.float32())), "f3": pa.array([None, None, [b"abc", b"def"], None], pa.list_(pa.binary())), "f4": pa.array([None, None, None, [8]], pa.list_(pa.int64())), }), dict(testcase_name="null_array", input_examples=[{ "a": None, }, { "a": None, }], expected_output={ "a": pa.array([None, None], type=pa.null()), }) ] class DecodedExamplesToArrowPyTest(parameterized.TestCase): @parameterized.named_parameters(*_INVALID_INPUT_TEST_CASES) def test_invalid_input(self, test_input, expected_error, expected_error_regexp): with self.assertRaisesRegex(expected_error, expected_error_regexp): decoded_examples_to_arrow.DecodedExamplesToTablePy(test_input) @parameterized.named_parameters(*_CONVERSION_TEST_CASES) def test_conversion(self, input_examples, expected_output): table = decoded_examples_to_arrow.DecodedExamplesToTablePy( input_examples)
def test_sparse_feature_generator_multiple_sparse_features(self):
  """Stats are computed independently for two sparse features.

  'sparse_feature' appears only in the first batch; 'other_sparse_feature'
  appears in both. index_feature2/other_index_feature2 are NullArrays in
  places, which should count as missing indices/length differences.
  """
  batches = [
      pa.Table.from_arrays([
          pa.array([
              None, None, ['a', 'b'], ['a', 'b'], ['a', 'b'], None, None
          ]),
          pa.array([[1, 2], [1, 2], None, None, None, None, None]),
          pa.array([[2, 4], [2, 4], [2, 4, 6], [2, 4, 6], [2, 4, 6], None,
                    None]),
          pa.array(
              [None, None, None, None, None, ['a', 'b'], ['a', 'b']]),
          pa.array([None, None, None, None, None, [2, 4], [2, 4]]),
          pa.array([None, None, None, None, None, None, None],
                   type=pa.null()),
      ], [
          'value_feature', 'index_feature1', 'index_feature2',
          'other_value_feature', 'other_index_feature1',
          'other_index_feature2'
      ]),
      pa.Table.from_arrays([
          pa.array(
              [None, None, None, None, None, ['a', 'b'], ['a', 'b']]),
          pa.array([None, None, None, None, None, [2, 4], [2, 4]]),
          pa.array([None, None, None, None, None, None, None],
                   type=pa.null())
      ], [
          'other_value_feature', 'other_index_feature1',
          'other_index_feature2'
      ]),
  ]
  schema = text_format.Parse(
      """ sparse_feature { name: 'sparse_feature' index_feature { name: 'index_feature1' } index_feature { name: 'index_feature2' } value_feature { name: 'value_feature' } } sparse_feature { name: 'other_sparse_feature' index_feature { name: 'other_index_feature1' } index_feature { name: 'other_index_feature2' } value_feature { name: 'other_value_feature' } } """,
      schema_pb2.Schema())
  expected_result = {
      types.FeaturePath(['sparse_feature']):
          text_format.Parse(
              """ path { step: 'sparse_feature' } custom_stats { name: 'missing_value' num: 2 } custom_stats { name: 'missing_index' rank_histogram { buckets { label: 'index_feature1' sample_count: 3 } buckets { label: 'index_feature2' sample_count: 0 } } } custom_stats { name: 'max_length_diff' rank_histogram { buckets { label: 'index_feature1' sample_count: 2 } buckets { label: 'index_feature2' sample_count: 2 } } } custom_stats { name: 'min_length_diff' rank_histogram { buckets { label: 'index_feature1' sample_count: -2 } buckets { label: 'index_feature2' sample_count: 1 } } }""",
              statistics_pb2.FeatureNameStatistics()),
      types.FeaturePath(['other_sparse_feature']):
          text_format.Parse(
              """ path { step: 'other_sparse_feature' } custom_stats { name: 'missing_value' num: 0 } custom_stats { name: 'missing_index' rank_histogram { buckets { label: 'other_index_feature1' sample_count: 0 } buckets { label: 'other_index_feature2' sample_count: 4 } } } custom_stats { name: 'max_length_diff' rank_histogram { buckets { label: 'other_index_feature1' sample_count: 0 } buckets { label: 'other_index_feature2' sample_count: -2 } } } custom_stats { name: 'min_length_diff' rank_histogram { buckets { label: 'other_index_feature1' sample_count: 0 } buckets { label: 'other_index_feature2' sample_count: -2 } } }""",
              statistics_pb2.FeatureNameStatistics())
  }
  generator = (
      sparse_feature_stats_generator.SparseFeatureStatsGenerator(schema))
  self.assertCombinerOutputEqual(batches, generator, expected_result)
def test_basic_stats_generator_handle_null_column(self):
  """Basic stats handle a column that is a NullArray in some batches."""
  # Feature 'a' covers null coming before non-null.
  # Feature 'b' covers null coming after non-null.
  b1 = pa.Table.from_arrays([
      pa.array([None, None, None], type=pa.null()),
      pa.array([[1.0, 2.0, 3.0], [4.0], [5.0]]),
  ], ['a', 'b'])
  b2 = pa.Table.from_arrays([
      pa.array([[1, 2], None], type=pa.list_(pa.int64())),
      pa.array([None, None], type=pa.null()),
  ], ['a', 'b'])
  batches = [b1, b2]
  expected_result = {
      types.FeaturePath(['a']):
          text_format.Parse(
              """ path { step: "a" } num_stats { common_stats { num_non_missing: 1 min_num_values: 2 max_num_values: 2 avg_num_values: 2.0 num_values_histogram { buckets { low_value: 2.0 high_value: 2.0 sample_count: 0.25 } buckets { low_value: 2.0 high_value: 2.0 sample_count: 0.25 } buckets { low_value: 2.0 high_value: 2.0 sample_count: 0.25 } buckets { low_value: 2.0 high_value: 2.0 sample_count: 0.25 } type: QUANTILES } tot_num_values: 2 } mean: 1.5 std_dev: 0.5 min: 1.0 median: 2.0 max: 2.0 histograms { buckets { low_value: 1.0 high_value: 1.3333333 sample_count: 0.9955556 } buckets { low_value: 1.3333333 high_value: 1.6666667 sample_count: 0.0022222 } buckets { low_value: 1.6666667 high_value: 2.0 sample_count: 1.0022222 } } histograms { buckets { low_value: 1.0 high_value: 1.0 sample_count: 0.5 } buckets { low_value: 1.0 high_value: 2.0 sample_count: 0.5 } buckets { low_value: 2.0 high_value: 2.0 sample_count: 0.5 } buckets { low_value: 2.0 high_value: 2.0 sample_count: 0.5 } type: QUANTILES } } """,
              statistics_pb2.FeatureNameStatistics()),
      types.FeaturePath(['b']):
          text_format.Parse(
              """ path { step: 'b' } type: FLOAT num_stats { common_stats { num_non_missing: 3 min_num_values: 1 max_num_values: 3 avg_num_values: 1.66666698456 num_values_histogram { buckets { low_value: 1.0 high_value: 1.0 sample_count: 0.75 } buckets { low_value: 1.0 high_value: 1.0 sample_count: 0.75 } buckets { low_value: 1.0 high_value: 3.0 sample_count: 0.75 } buckets { low_value: 3.0 high_value: 3.0 sample_count: 0.75 } type: QUANTILES } tot_num_values: 5 } mean: 3.0 std_dev: 1.4142136 min: 1.0 median: 3.0 max: 5.0 histograms { buckets { low_value: 1.0 high_value: 2.3333333 sample_count: 1.9888889 } buckets { low_value: 2.3333333 high_value: 3.6666667 sample_count: 1.0055556 } buckets { low_value: 3.6666667 high_value: 5.0 sample_count: 2.0055556 } } histograms { buckets { low_value: 1.0 high_value: 2.0 sample_count: 1.25 } buckets { low_value: 2.0 high_value: 3.0 sample_count: 1.25 } buckets { low_value: 3.0 high_value: 4.0 sample_count: 1.25 } buckets { low_value: 4.0 high_value: 5.0 sample_count: 1.25 } type: QUANTILES } } """,
              statistics_pb2.FeatureNameStatistics()),
  }
  generator = basic_stats_generator.BasicStatsGenerator(
      num_values_histogram_buckets=4,
      num_histogram_buckets=3,
      num_quantiles_histogram_buckets=4)
  self.assertCombinerOutputEqual(batches, generator, expected_result)
def test_all_null_mask_no_paths(self):
  """Calling all_null_mask with no paths raises ValueError.

  Bug fix: the original built the column as
  `pa.array([None, [1]], type=pa.null())`, but a pyarrow NullArray cannot
  contain the non-null value [1] — construction raises ArrowInvalid, so
  the test errored before reaching the assertion under test. The explicit
  null type is dropped so pyarrow infers list<int64>; the column contents
  are irrelevant to the empty-paths check being exercised.
  """
  batch = input_batch.InputBatch(
      pa.Table.from_arrays([pa.array([None, [1]])], ['f3']))
  with self.assertRaisesRegex(ValueError, r'Paths cannot be empty.*'):
    batch.all_null_mask()
def test_topk_uniques_combiner_with_categorical_feature(self):
  """Combiner top-k/uniques treats a categorical INT feature as strings.

  The final all-null batch must not contribute any values.
  """
  # fa: 4 12, 2 23, 2 34, 2 45
  batches = [
      pa.Table.from_arrays([pa.array([[12, 23, 34, 12], [45, 23]])],
                           ['fa']),
      pa.Table.from_arrays([pa.array([[12, 12, 34, 45]])], ['fa']),
      pa.Table.from_arrays(
          [pa.array([None, None, None, None], type=pa.null())], ['fa']),
  ]
  expected_result = {
      types.FeaturePath(['fa']):
          text_format.Parse(
              """ path { step: 'fa' } type: INT string_stats { unique: 4 top_values { value: '12' frequency: 4 } top_values { value: '45' frequency: 2 } top_values { value: '34' frequency: 2 } top_values { value: '23' frequency: 2 } rank_histogram { buckets { low_rank: 0 high_rank: 0 label: "12" sample_count: 4.0 } buckets { low_rank: 1 high_rank: 1 label: "45" sample_count: 2.0 } buckets { low_rank: 2 high_rank: 2 label: "34" sample_count: 2.0 } } }""",
              statistics_pb2.FeatureNameStatistics())
  }
  schema = text_format.Parse(
      """ feature { name: "fa" type: INT int_domain { is_categorical: true } } """,
      schema_pb2.Schema())
  generator = (
      top_k_uniques_combiner_stats_generator
      .TopKUniquesCombinerStatsGenerator(
          schema=schema, num_top_values=4, num_rank_histogram_buckets=3))
  self.assertCombinerOutputEqual(batches, generator, expected_result)
class WeightedFeatureStatsGeneratorTest(parameterized.TestCase,
                                        test_util.CombinerStatsGeneratorTest):
  """Tests for WeightedFeatureStatsGenerator."""

  # Each case supplies input batches plus the four expected custom stats
  # for a weighted feature pairing 'value' with 'weight'.
  @parameterized.named_parameters(
      {
          'testcase_name': 'AllMatching',
          'batches': [
              pa.Table.from_arrays(
                  [pa.array([['a'], ['a', 'b']]),
                   pa.array([[2], [2, 4]])], ['value', 'weight'])
          ],
          'expected_missing_weight': 0.0,
          'expected_missing_value': 0.0,
          'expected_min_weight_length_diff': 0.0,
          'expected_max_weight_length_diff': 0.0
      },
      {
          'testcase_name': 'AllMatchingMultiBatch',
          'batches': [
              pa.Table.from_arrays(
                  [pa.array([['a'], ['a', 'b']]),
                   pa.array([[2], [2, 4]])], ['value', 'weight']),
              pa.Table.from_arrays(
                  [pa.array([['a'], ['a', 'b']]),
                   pa.array([[2], [2, 4]])], ['value', 'weight'])
          ],
          'expected_missing_weight': 0.0,
          'expected_missing_value': 0.0,
          'expected_min_weight_length_diff': 0.0,
          'expected_max_weight_length_diff': 0.0
      },
      {
          'testcase_name': 'LengthMismatchPositive',
          'batches': [
              pa.Table.from_arrays(
                  [pa.array([['a'], ['a']]),
                   pa.array([[2], [2, 4]])], ['value', 'weight'])
          ],
          'expected_missing_weight': 0.0,
          'expected_missing_value': 0.0,
          'expected_min_weight_length_diff': 0.0,
          'expected_max_weight_length_diff': 1.0
      },
      {
          'testcase_name': 'LengthMismatchNegative',
          'batches': [
              pa.Table.from_arrays(
                  [pa.array([['a'], ['a', 'b']]),
                   pa.array([[2], [2]])], ['value', 'weight'])
          ],
          'expected_missing_weight': 0.0,
          'expected_missing_value': 0.0,
          'expected_min_weight_length_diff': -1.0,
          'expected_max_weight_length_diff': 0.0
      },
      {
          'testcase_name': 'LengthMismatchMultiBatch',
          'batches': [
              pa.Table.from_arrays(
                  [pa.array([['a'], ['a', 'b']]),
                   pa.array([[], []])], ['value', 'weight']),
              pa.Table.from_arrays([pa.array([[1], [1, 1]])], ['other'])
          ],
          'expected_missing_weight': 0.0,
          'expected_missing_value': 0.0,
          'expected_min_weight_length_diff': -2.0,
          'expected_max_weight_length_diff': -1.0
      },
      {
          'testcase_name': 'SomePairsMissing',
          'batches': [
              pa.Table.from_arrays([
                  pa.array([['a'], None, ['a', 'b']]),
                  pa.array([[1, 1], None, [1, 1, 1]])
              ], ['value', 'weight'])
          ],
          'expected_missing_weight': 0.0,
          'expected_missing_value': 0.0,
          'expected_min_weight_length_diff': 1.0,
          'expected_max_weight_length_diff': 1.0
      },
      {
          'testcase_name': 'EmptyWeights',
          'batches': [pa.Table.from_arrays([pa.array([['a'], ['a', 'b']])],
                                           ['value'])],
          'expected_missing_weight': 2.0,
          'expected_missing_value': 0.0,
          'expected_min_weight_length_diff': -2.0,
          'expected_max_weight_length_diff': -1.0
      },
      {
          'testcase_name': 'EmptyValues',
          'batches': [pa.Table.from_arrays([pa.array([[1], [1, 2]])],
                                           ['weight'])],
          'expected_missing_weight': 0.0,
          'expected_missing_value': 2.0,
          'expected_min_weight_length_diff': 1.0,
          'expected_max_weight_length_diff': 2.0
      },
      {
          'testcase_name': 'EmptyWeightsAndValues',
          'batches': [pa.Table.from_arrays([])],
          'expected_missing_weight': 0.0,
          'expected_missing_value': 0.0,
          'expected_min_weight_length_diff': 0.0,
          'expected_max_weight_length_diff': 0.0
      },
      {
          # A NullArray weight column counts every weight as missing.
          'testcase_name': 'NullWeightArray',
          'batches': [
              pa.Table.from_arrays([
                  pa.array([['a'], ['a', 'b']]),
                  pa.array([None, None], type=pa.null())
              ], ['value', 'weight'])
          ],
          'expected_missing_weight': 2.0,
          'expected_missing_value': 0.0,
          'expected_min_weight_length_diff': -2.0,
          'expected_max_weight_length_diff': -1.0
      })
  def test_single_weighted_feature(self, batches, expected_missing_weight,
                                   expected_missing_value,
                                   expected_min_weight_length_diff,
                                   expected_max_weight_length_diff):
    schema = text_format.Parse(
        """ weighted_feature { name: 'weighted_feature' feature { step: 'value' } weight_feature { step: 'weight' } } """,
        schema_pb2.Schema())
    generator = (weighted_feature_stats_generator
                 .WeightedFeatureStatsGenerator(schema))
    # Assemble the expected FeatureNameStatistics from the parameterized
    # expectations.
    expected_stats = statistics_pb2.FeatureNameStatistics()
    expected_stats.path.step.append('weighted_feature')
    expected_stats.custom_stats.add(name='missing_weight',
                                    num=expected_missing_weight)
    expected_stats.custom_stats.add(name='missing_value',
                                    num=expected_missing_value)
    expected_stats.custom_stats.add(name='min_weight_length_diff',
                                    num=expected_min_weight_length_diff)
    expected_stats.custom_stats.add(name='max_weight_length_diff',
                                    num=expected_max_weight_length_diff)
    expected_result = {
        types.FeaturePath(['weighted_feature']): expected_stats
    }
    self.assertCombinerOutputEqual(batches, generator, expected_result)

  def test_shared_weight(self):
    """Two weighted features may share a single weight feature."""
    batches = [
        pa.Table.from_arrays([
            pa.array([['a'], ['a', 'b'], ['a']]),
            pa.array([['x'], ['y'], ['x']]),
            pa.array([[2], [4], None])
        ], ['value1', 'value2', 'weight'])
    ]
    schema = text_format.Parse(
        """ weighted_feature { name: 'weighted_feature1' feature { step: 'value1' } weight_feature { step: 'weight' } } weighted_feature { name: 'weighted_feature2' feature { step: 'value2' } weight_feature { step: 'weight' } }""",
        schema_pb2.Schema())
    generator = (weighted_feature_stats_generator
                 .WeightedFeatureStatsGenerator(schema))
    expected_result = {
        types.FeaturePath(['weighted_feature1']):
            text_format.Parse(
                """ path { step: 'weighted_feature1' } custom_stats { name: 'missing_weight' num: 1.0 } custom_stats { name: 'missing_value' num: 0.0 } custom_stats { name: 'min_weight_length_diff' num: -1.0 } custom_stats { name: 'max_weight_length_diff' num: 0.0 }""",
                statistics_pb2.FeatureNameStatistics()),
        types.FeaturePath(['weighted_feature2']):
            text_format.Parse(
                """ path { step: 'weighted_feature2' } custom_stats { name: 'missing_weight' num: 1.0 } custom_stats { name: 'missing_value' num: 0.0 } custom_stats { name: 'min_weight_length_diff' num: -1.0 } custom_stats { name: 'max_weight_length_diff' num: 0.0 }""",
                statistics_pb2.FeatureNameStatistics())
    }
    self.assertCombinerOutputEqual(batches, generator, expected_result)
def test_mi_regression_with_null_array(self): label_array = pa.array([[0.1], [0.2], [0.8], [0.7], [0.2], [0.3], [0.9], [0.4], [0.1], [0.0], [0.4], [0.6], [0.4], [0.8]]) # Random floats that do not map onto the label terrible_feat_array = pa.array([[0.4], [0.1], [0.4], [0.4], [0.8], [0.7], [0.2], [0.1], [0.0], [0.4], [0.8], [0.2], [0.5], [0.1]]) null_array = pa.array([None] * 14, type=pa.null()) # Note: It is possible to get different results for py2 and py3, depending # on the feature name used (e.g., if use 'empty_feature', the results # differ). This might be due to the scikit learn function used to compute MI # adding a small amount of noise to continuous features before computing MI. batch = pa.Table.from_arrays( [label_array, label_array, terrible_feat_array, null_array], [ "label_key", "perfect_feature", "terrible_feature", "values_empty_feature" ]) schema = text_format.Parse( """ feature { name: "values_empty_feature" type: FLOAT shape { dim { size: 1 } } } feature { name: "perfect_feature" type: FLOAT shape { dim { size: 1 } } } feature { name: "terrible_feature" type: FLOAT shape { dim { size: 1 } } } feature { name: "label_key" type: FLOAT shape { dim { size: 1 } } } """, schema_pb2.Schema()) expected = text_format.Parse( """ features { path { step: "perfect_feature" } custom_stats { name: "sklearn_adjusted_mutual_information" num: 1.0742656 } custom_stats { name: "sklearn_mutual_information" num: 1.2277528 } } features { path { step: "terrible_feature" } custom_stats { name: "sklearn_adjusted_mutual_information" num: 0.0392891 } custom_stats { name: "sklearn_mutual_information" num: 0.0392891 } } features { path { step: "values_empty_feature" } custom_stats { name: "sklearn_adjusted_mutual_information" num: 0.0 } custom_stats { name: "sklearn_mutual_information" num: 0.0 } }""", statistics_pb2.DatasetFeatureStatistics()) self._assert_mi_output_equal(batch, expected, schema, types.FeaturePath(["label_key"]))